use context to graceful shutdown

master
dingjun 6 years ago
parent 96e0332d16
commit 51e00f2c99

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"io" "io"
@ -10,6 +11,7 @@ import (
"os/signal" "os/signal"
"strings" "strings"
"syscall" "syscall"
"time"
glog "github.com/fangdingjun/go-log" glog "github.com/fangdingjun/go-log"
proxyproto "github.com/pires/go-proxyproto" proxyproto "github.com/pires/go-proxyproto"
@ -59,7 +61,7 @@ func getSNIServerName(buf []byte) string {
return msg.serverName return msg.serverName
} }
func forward(c net.Conn, data []byte, dst string) { func forward(ctx context.Context, c net.Conn, data []byte, dst string) {
addr := dst addr := dst
proxyProto := 0 proxyProto := 0
@ -90,8 +92,8 @@ func forward(c net.Conn, data []byte, dst string) {
hdr.Version = 2 hdr.Version = 2
} }
} }
dialer := &net.Dialer{Timeout: 10 * time.Second}
c1, err := net.Dial("tcp", addr) c1, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil { if err != nil {
glog.Error(err) glog.Error(err)
return return
@ -101,7 +103,10 @@ func forward(c net.Conn, data []byte, dst string) {
if proxyProto != 0 { if proxyProto != 0 {
glog.Debugf("send proxy proto v%d to %s", proxyProto, addr) glog.Debugf("send proxy proto v%d to %s", proxyProto, addr)
hdr.WriteTo(c1) if _, err = hdr.WriteTo(c1); err != nil {
glog.Errorln(err)
return
}
} }
if _, err = c1.Write(data); err != nil { if _, err = c1.Write(data); err != nil {
@ -121,7 +126,10 @@ func forward(c net.Conn, data []byte, dst string) {
ch <- struct{}{} ch <- struct{}{}
}() }()
<-ch select {
case <-ch:
case <-ctx.Done():
}
} }
func getDST(c net.Conn, name string) string { func getDST(c net.Conn, name string) string {
@ -134,7 +142,7 @@ func getDefaultDST() string {
return cfg.Default return cfg.Default
} }
func serve(c net.Conn) { func serve(ctx context.Context, c net.Conn) {
defer c.Close() defer c.Close()
buf := make([]byte, 1024) buf := make([]byte, 1024)
@ -146,7 +154,7 @@ func serve(c net.Conn) {
servername := getSNIServerName(buf[:n]) servername := getSNIServerName(buf[:n])
if servername == "" { if servername == "" {
glog.Debugf("no sni, send to default") glog.Debugf("no sni, send to default")
forward(c, buf[:n], getDefaultDST()) forward(ctx, c, buf[:n], getDefaultDST())
return return
} }
dst := getDST(c, servername) dst := getDST(c, servername)
@ -154,7 +162,7 @@ func serve(c net.Conn) {
dst = getDefaultDST() dst = getDefaultDST()
glog.Debugf("use default dst %s for sni %s", dst, servername) glog.Debugf("use default dst %s for sni %s", dst, servername)
} }
forward(c, buf[:n], dst) forward(ctx, c, buf[:n], dst)
} }
var cfg conf var cfg conf
@ -188,28 +196,34 @@ func main() {
glog.Default.Level = lv glog.Default.Level = lv
} }
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for _, d := range cfg.Listen { for _, d := range cfg.Listen {
glog.Infof("listen on :%d", d) glog.Infof("listen on :%d", d)
l, err := net.Listen("tcp", fmt.Sprintf(":%d", d)) lc := &net.ListenConfig{}
l, err := lc.Listen(ctx, "tcp", fmt.Sprintf(":%d", d))
if err != nil { if err != nil {
glog.Fatal(err) glog.Fatal(err)
} }
go func(l net.Listener) { go func(ctx context.Context, l net.Listener) {
defer l.Close() defer l.Close()
for { for {
c1, err := l.Accept() c1, err := l.Accept()
if err != nil { if err != nil {
glog.Fatal(err) glog.Error(err)
break
} }
go serve(c1) go serve(ctx, c1)
} }
}(l) }(ctx, l)
} }
ch := make(chan os.Signal, 2) ch := make(chan os.Signal, 2)
signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM) signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM)
select { select {
case s := <-ch: case s := <-ch:
cancel()
glog.Printf("received signal %s, exit.", s) glog.Printf("received signal %s, exit.", s)
} }
} }

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"crypto/tls" "crypto/tls"
"io/ioutil" "io/ioutil"
"net" "net"
@ -36,7 +37,7 @@ func TestProxyProto(t *testing.T) {
log.Errorln(err) log.Errorln(err)
return return
} }
go serve(conn) go serve(context.Background(), conn)
} }
}() }()
cert, err := tls.LoadX509KeyPair("server.crt", "server.key") cert, err := tls.LoadX509KeyPair("server.crt", "server.key")

Loading…
Cancel
Save