diff --git a/handler.go b/handler.go index 88baa4c..893f3d9 100644 --- a/handler.go +++ b/handler.go @@ -7,19 +7,21 @@ import ( "fmt" "io" "io/ioutil" - "log" "net" "net/http" "net/url" "strings" + + "github.com/fangdingjun/go-log" ) func initHandler() { var _certs []tls.Certificate for _, _c := range _config.Certificate { + log.Debugf("load certificate %s, %s...", _c.Cert, _c.Key) _cert, err := _c.load() if err != nil { - log.Println("load certificate failed", err) + log.Errorln("load certificate failed", err) continue } _certs = append(_certs, _cert) @@ -41,7 +43,7 @@ func initHandler() { for { c, err := l.Accept() if err != nil { - log.Printf("accept error: %s", err) + log.Errorf("accept error: %s", err) continue } go handleConnection(c) @@ -51,20 +53,26 @@ func initHandler() { } func handleConnection(c net.Conn) { - //log.Printf("connection from %s", c.RemoteAddr().String()) + defer func() { + log.Debugf("close connection %s", c.RemoteAddr()) + if err := c.Close(); err != nil { + log.Debugf("close connection %s error: %s", c.RemoteAddr(), err) + } + }() + + log.Debugf("connection from %s\n", c.RemoteAddr().String()) tlsconn := c.(*tls.Conn) - defer tlsconn.Close() connstate := tlsconn.ConnectionState() for !connstate.HandshakeComplete { if err := tlsconn.Handshake(); err != nil { - log.Printf("handshake error: %s", err) + log.Errorf("%s tls handshake error: %s", c.RemoteAddr(), err) return } connstate = tlsconn.ConnectionState() } - //log.Printf("handshake complete") + log.Debugf("%s handshake complete", c.RemoteAddr()) servername := connstate.ServerName @@ -79,13 +87,13 @@ func handleConnection(c net.Conn) { if backend == nil { _b, err := url.Parse(_config.DefaultBackend) if err != nil { - log.Printf("parse addr error: %s", err) + log.Errorf("parse addr error: %s", err) return } backend = _b } - log.Printf("connection from %s, tls version 0x%x, sni: %s, forward to: %s\n", + log.Debugf("connection from %s, tls version 0x%x, sni: %s, forward to: %s", c.RemoteAddr().String(), connstate.Version, servername, @@ -99,39 +107,33 @@ func handleForward(c *tls.Conn, b *url.URL) { var remote net.Conn var err error - //log.Printf("forward to %s", b.String()) + log.Debugf("%s forward to %s", c.RemoteAddr(), b.String()) switch b.Scheme { + case "http": + httpForward(c, b) + return case "tcp": + //log.Debugf("forward to tcp backend %s", b.Host) remote, err = net.Dial("tcp", b.Host) case "unix": + //log.Debugf("forward to unix backend %s", b.Host) remote, err = net.Dial("unix", b.Host) - case "http": - if !strings.Contains(b.Host, ":") { - b.Host = fmt.Sprintf("%s:80", b.Host) - } - remote, err = net.Dial("tcp", b.Host) - if err != nil { - log.Printf("dial to %s error: %s", b.Host, err) - writeErrResponse(c, http.StatusBadGateway, err.Error()) - return - } - httpForward(c, remote) - return case "tls": + //log.Debugf("forward to tls backend %s", b.Host) h, _, _ := net.SplitHostPort(b.Host) remote, err = tls.Dial("tcp", b.Host, &tls.Config{ServerName: h}) default: - log.Printf("backend type '%s' is not supported", b.Scheme) + log.Errorf("backend type '%s' is not supported", b.Scheme) return } if err != nil { - log.Printf("dail to backend %s error: %s", b.String(), err) + log.Errorf("dail to backend %s error: %s", b.String(), err) return } if remote == nil { - log.Println("remote is nil") + log.Warnln("remote is nil") return } //log.Println("begin data forward") @@ -173,18 +175,29 @@ func writeErrResponse(w io.Writer, status int, content string) { res.Write(w) } -func httpForward(r, b net.Conn) { - defer b.Close() +func httpForward(r net.Conn, b *url.URL) { + //log.Debugf("forward to http backend %s", b.Host) + if !strings.Contains(b.Host, ":") { + b.Host = fmt.Sprintf("%s:80", b.Host) + } + backend, err := net.Dial("tcp", b.Host) + if err != nil { + log.Errorf("dial to %s error: %s", b.Host, err) + // consume the client request, + // or client will get forcely close connection error + http.ReadRequest(bufio.NewReader(r)) + writeErrResponse(r, http.StatusBadGateway, err.Error()) + return + } + defer backend.Close() rb := bufio.NewReader(r) - bb := bufio.NewReader(b) + bb := bufio.NewReader(backend) for { req, err := http.ReadRequest(rb) if err != nil { if err != io.EOF { - log.Printf("read http request error: %s", err) - //fmt.Fprintf(b, "HTTP/1.1 504 Bad gateway\r\nConnection: close\r\n\r\n") - //writeErrResponse(b, http.StatusBadGateway, err.Error()) + log.Errorf("read http request error: %s", err) } return } @@ -195,14 +208,14 @@ func httpForward(r, b net.Conn) { req.Header.Add("X-Real-Ip", addr.IP.String()) //log.Printf("%+v\n", req.Header) - err = req.Write(b) + err = req.Write(backend) if req.Body != nil { req.Body.Close() } if err != nil { - log.Printf("write request to backend error: %s", err) + log.Errorf("write request to backend error: %s", err) writeErrResponse(r, http.StatusBadGateway, err.Error()) return } @@ -210,7 +223,7 @@ func httpForward(r, b net.Conn) { res, err := http.ReadResponse(bb, req) if err != nil { if err != io.EOF { - log.Printf("read http response from backend error: %s", err) + log.Errorf("read http response from backend error: %s", err) } writeErrResponse(r, http.StatusBadGateway, "gateway error") return @@ -223,7 +236,7 @@ func httpForward(r, b net.Conn) { } if err != nil { - log.Printf("write response to client error: %s", err) + log.Errorf("write response to client error: %s", err) return } } diff --git a/main.go b/main.go index f0fadae..99ec93d 100644 --- a/main.go +++ b/main.go @@ -2,22 +2,56 @@ package main import ( "flag" - "log" + "fmt" + "os" + + log "github.com/fangdingjun/go-log" + "github.com/fangdingjun/go-log/formatters" + "github.com/fangdingjun/go-log/writers" ) var _config *conf func main() { var configfile string + var logfile string + var loglevel string + var logFileSize int64 + var logKeepCount int + flag.StringVar(&configfile, "c", "config.yaml", "config file") + flag.StringVar(&logfile, "log_file", "", "log file path, default to stdout") + flag.StringVar(&loglevel, "log_level", "INFO", "log level, levels are: \nOFF, FATAL, PANIC, ERROR, WARN, INFO, DEBUG") + flag.Int64Var(&logFileSize, "log_file_size", 10, "max log file size, MB") + flag.IntVar(&logKeepCount, "log_count", 10, "max count of log file to keep") flag.Parse() + if logfile != "" { + log.Default.Out = &writers.FixedSizeFileWriter{ + Name: logfile, + MaxSize: logFileSize * 1024 * 1024, + MaxCount: logKeepCount, + } + } + if loglevel != "" { + lvname, err := log.ParseLevel(loglevel) + if err != nil { + fmt.Fprintf(os.Stderr, "invalid level %s", loglevel) + os.Exit(1) + } + log.Default.Level = lvname + } + + log.Default.Formatter = &formatters.TextFormatter{ + TimeFormat: "2006-01-02 15:04:05.000", + } + cfg, err := loadConfig(configfile) if err != nil { log.Fatalf("load config file error: %s", err) } _config = cfg - + log.Debugf("config: %+v", _config) initHandler() select {} }