change log library

master
fangdingjun 7 years ago
parent d43a1f8494
commit d58d43e0e6

@ -7,19 +7,21 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"log"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"github.com/fangdingjun/go-log"
) )
func initHandler() { func initHandler() {
var _certs []tls.Certificate var _certs []tls.Certificate
for _, _c := range _config.Certificate { for _, _c := range _config.Certificate {
log.Debugf("load certificate %s, %s...", _c.Cert, _c.Key)
_cert, err := _c.load() _cert, err := _c.load()
if err != nil { if err != nil {
log.Println("load certificate failed", err) log.Errorln("load certificate failed", err)
continue continue
} }
_certs = append(_certs, _cert) _certs = append(_certs, _cert)
@ -41,7 +43,7 @@ func initHandler() {
for { for {
c, err := l.Accept() c, err := l.Accept()
if err != nil { if err != nil {
log.Printf("accept error: %s", err) log.Errorf("accept error: %s", err)
continue continue
} }
go handleConnection(c) go handleConnection(c)
@ -51,20 +53,26 @@ func initHandler() {
} }
func handleConnection(c net.Conn) { func handleConnection(c net.Conn) {
//log.Printf("connection from %s", c.RemoteAddr().String()) defer func() {
log.Debugf("close connection %s", c.RemoteAddr())
if err := c.Close(); err != nil {
log.Debugf("close connection %s error: %s", c.RemoteAddr(), err)
}
}()
log.Debugf("connection from %s\n", c.RemoteAddr().String())
tlsconn := c.(*tls.Conn) tlsconn := c.(*tls.Conn)
defer tlsconn.Close()
connstate := tlsconn.ConnectionState() connstate := tlsconn.ConnectionState()
for !connstate.HandshakeComplete { for !connstate.HandshakeComplete {
if err := tlsconn.Handshake(); err != nil { if err := tlsconn.Handshake(); err != nil {
log.Printf("handshake error: %s", err) log.Errorf("%s tls handshake error: %s", c.RemoteAddr(), err)
return return
} }
connstate = tlsconn.ConnectionState() connstate = tlsconn.ConnectionState()
} }
//log.Printf("handshake complete") log.Debugf("%s handshake complete", c.RemoteAddr())
servername := connstate.ServerName servername := connstate.ServerName
@ -79,13 +87,13 @@ func handleConnection(c net.Conn) {
if backend == nil { if backend == nil {
_b, err := url.Parse(_config.DefaultBackend) _b, err := url.Parse(_config.DefaultBackend)
if err != nil { if err != nil {
log.Printf("parse addr error: %s", err) log.Errorf("parse addr error: %s", err)
return return
} }
backend = _b backend = _b
} }
log.Printf("connection from %s, tls version 0x%x, sni: %s, forward to: %s\n", log.Debugf("connection from %s, tls version 0x%x, sni: %s, forward to: %s",
c.RemoteAddr().String(), c.RemoteAddr().String(),
connstate.Version, connstate.Version,
servername, servername,
@ -99,39 +107,33 @@ func handleForward(c *tls.Conn, b *url.URL) {
var remote net.Conn var remote net.Conn
var err error var err error
//log.Printf("forward to %s", b.String()) log.Debugf("%s forward to %s", c.RemoteAddr(), b.String())
switch b.Scheme { switch b.Scheme {
case "http":
httpForward(c, b)
return
case "tcp": case "tcp":
//log.Debugf("forward to tcp backend %s", b.Host)
remote, err = net.Dial("tcp", b.Host) remote, err = net.Dial("tcp", b.Host)
case "unix": case "unix":
//log.Debugf("forward to unix backend %s", b.Host)
remote, err = net.Dial("unix", b.Host) remote, err = net.Dial("unix", b.Host)
case "http":
if !strings.Contains(b.Host, ":") {
b.Host = fmt.Sprintf("%s:80", b.Host)
}
remote, err = net.Dial("tcp", b.Host)
if err != nil {
log.Printf("dial to %s error: %s", b.Host, err)
writeErrResponse(c, http.StatusBadGateway, err.Error())
return
}
httpForward(c, remote)
return
case "tls": case "tls":
//log.Debugf("forward to tls backend %s", b.Host)
h, _, _ := net.SplitHostPort(b.Host) h, _, _ := net.SplitHostPort(b.Host)
remote, err = tls.Dial("tcp", b.Host, &tls.Config{ServerName: h}) remote, err = tls.Dial("tcp", b.Host, &tls.Config{ServerName: h})
default: default:
log.Printf("backend type '%s' is not supported", b.Scheme) log.Errorf("backend type '%s' is not supported", b.Scheme)
return return
} }
if err != nil { if err != nil {
log.Printf("dail to backend %s error: %s", b.String(), err) log.Errorf("dail to backend %s error: %s", b.String(), err)
return return
} }
if remote == nil { if remote == nil {
log.Println("remote is nil") log.Warnln("remote is nil")
return return
} }
//log.Println("begin data forward") //log.Println("begin data forward")
@ -173,18 +175,29 @@ func writeErrResponse(w io.Writer, status int, content string) {
res.Write(w) res.Write(w)
} }
func httpForward(r, b net.Conn) { func httpForward(r net.Conn, b *url.URL) {
defer b.Close() //log.Debugf("forward to http backend %s", b.Host)
if !strings.Contains(b.Host, ":") {
b.Host = fmt.Sprintf("%s:80", b.Host)
}
backend, err := net.Dial("tcp", b.Host)
if err != nil {
log.Errorf("dial to %s error: %s", b.Host, err)
// consume the client request,
// or client will get forcely close connection error
http.ReadRequest(bufio.NewReader(r))
writeErrResponse(r, http.StatusBadGateway, err.Error())
return
}
defer backend.Close()
rb := bufio.NewReader(r) rb := bufio.NewReader(r)
bb := bufio.NewReader(b) bb := bufio.NewReader(backend)
for { for {
req, err := http.ReadRequest(rb) req, err := http.ReadRequest(rb)
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
log.Printf("read http request error: %s", err) log.Errorf("read http request error: %s", err)
//fmt.Fprintf(b, "HTTP/1.1 504 Bad gateway\r\nConnection: close\r\n\r\n")
//writeErrResponse(b, http.StatusBadGateway, err.Error())
} }
return return
} }
@ -195,14 +208,14 @@ func httpForward(r, b net.Conn) {
req.Header.Add("X-Real-Ip", addr.IP.String()) req.Header.Add("X-Real-Ip", addr.IP.String())
//log.Printf("%+v\n", req.Header) //log.Printf("%+v\n", req.Header)
err = req.Write(b) err = req.Write(backend)
if req.Body != nil { if req.Body != nil {
req.Body.Close() req.Body.Close()
} }
if err != nil { if err != nil {
log.Printf("write request to backend error: %s", err) log.Errorf("write request to backend error: %s", err)
writeErrResponse(r, http.StatusBadGateway, err.Error()) writeErrResponse(r, http.StatusBadGateway, err.Error())
return return
} }
@ -210,7 +223,7 @@ func httpForward(r, b net.Conn) {
res, err := http.ReadResponse(bb, req) res, err := http.ReadResponse(bb, req)
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
log.Printf("read http response from backend error: %s", err) log.Errorf("read http response from backend error: %s", err)
} }
writeErrResponse(r, http.StatusBadGateway, "gateway error") writeErrResponse(r, http.StatusBadGateway, "gateway error")
return return
@ -223,7 +236,7 @@ func httpForward(r, b net.Conn) {
} }
if err != nil { if err != nil {
log.Printf("write response to client error: %s", err) log.Errorf("write response to client error: %s", err)
return return
} }
} }

@ -2,22 +2,56 @@ package main
import ( import (
"flag" "flag"
"log" "fmt"
"os"
log "github.com/fangdingjun/go-log"
"github.com/fangdingjun/go-log/formatters"
"github.com/fangdingjun/go-log/writers"
) )
var _config *conf var _config *conf
func main() { func main() {
var configfile string var configfile string
var logfile string
var loglevel string
var logFileSize int64
var logKeepCount int
flag.StringVar(&configfile, "c", "config.yaml", "config file") flag.StringVar(&configfile, "c", "config.yaml", "config file")
flag.StringVar(&logfile, "log_file", "", "log file path, default to stdout")
flag.StringVar(&loglevel, "log_level", "INFO", "log level, levels are: \nOFF, FATAL, PANIC, ERROR, WARN, INFO, DEBUG")
flag.Int64Var(&logFileSize, "log_file_size", 10, "max log file size, MB")
flag.IntVar(&logKeepCount, "log_count", 10, "max count of log file to keep")
flag.Parse() flag.Parse()
if logfile != "" {
log.Default.Out = &writers.FixedSizeFileWriter{
Name: logfile,
MaxSize: logFileSize * 1024 * 1024,
MaxCount: logKeepCount,
}
}
if loglevel != "" {
lvname, err := log.ParseLevel(loglevel)
if err != nil {
fmt.Fprintf(os.Stderr, "invalid level %s", loglevel)
os.Exit(1)
}
log.Default.Level = lvname
}
log.Default.Formatter = &formatters.TextFormatter{
TimeFormat: "2006-01-02 15:04:05.000",
}
cfg, err := loadConfig(configfile) cfg, err := loadConfig(configfile)
if err != nil { if err != nil {
log.Fatalf("load config file error: %s", err) log.Fatalf("load config file error: %s", err)
} }
_config = cfg _config = cfg
log.Debugf("config: %+v", _config)
initHandler() initHandler()
select {} select {}
} }

Loading…
Cancel
Save