Initial commit

master
dingjun 8 years ago
commit 30000054d0

2
.gitignore vendored

@ -0,0 +1,2 @@
*~
fileserver*

@ -0,0 +1,75 @@
package main
import (
"context"
"net"
"time"
)
var dialer *net.Dialer
func dial(network, addr string) (net.Conn, error) {
var err error
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
if ip := net.ParseIP(host); ip != nil {
return dialer.Dial(network, addr)
}
ips, err := lookupHost(host)
if err != nil {
return nil, err
}
var conn net.Conn
for _, ip := range ips {
address := net.JoinHostPort(ip, port)
if conn, err = dialer.Dial(network, address); err == nil {
return conn, err
}
}
// return last error
return conn, err
}
func dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
var err error
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
if ip := net.ParseIP(host); ip != nil {
return dialer.DialContext(ctx, network, addr)
}
ips, err := lookupHost(host)
if err != nil {
return nil, err
}
var conn net.Conn
for _, ip := range ips {
address := net.JoinHostPort(ip, port)
if conn, err = dialer.DialContext(ctx, network, address); err == nil {
return conn, err
}
}
// return last error
return conn, err
}
func init() {
dialer = &net.Dialer{
Timeout: 2 * time.Second,
}
}

127
dns.go

@ -0,0 +1,127 @@
// +build linux darwin
package main
import (
"fmt"
"github.com/miekg/dns"
"log"
"time"
)
var clientConfig *dns.ClientConfig
var dnsClient *dns.Client
func lookupHost(host string) ([]string, error) {
var result = []string{}
var err error
ret, err1 := getAAAA(host)
if err1 == nil {
result = append(result, ret...)
} else {
err = err1
}
ret1, err2 := getA(host)
if err2 == nil {
result = append(result, ret1...)
} else {
err = err2
}
if len(result) > 0 {
return result, nil
}
if err == nil {
return nil, fmt.Errorf("dns lookup failed for %s", host)
}
return nil, err
}
func getA(host string) ([]string, error) {
var err error
var msg *dns.Msg
var result = []string{}
m := new(dns.Msg)
m.SetQuestion(dns.Fqdn(host), dns.TypeA)
for _, srv := range clientConfig.Servers {
dnsserver := fmt.Sprintf("%s:%s", srv, clientConfig.Port)
msg, _, err = dnsClient.Exchange(m, dnsserver)
if err == nil {
break
} else {
log.Println(err)
}
}
if err != nil {
return result, err
}
for _, rr := range msg.Answer {
if a, ok := rr.(*dns.A); ok {
result = append(result, a.A.String())
}
}
return result, nil
}
func getAAAA(host string) ([]string, error) {
var err error
var msg *dns.Msg
var result = []string{}
m := new(dns.Msg)
m.SetQuestion(dns.Fqdn(host), dns.TypeAAAA)
for _, srv := range clientConfig.Servers {
dnsserver := fmt.Sprintf("%s:%s", srv, clientConfig.Port)
msg, _, err = dnsClient.Exchange(m, dnsserver)
if err == nil {
break
} else {
log.Println(err)
}
}
if err != nil {
return result, err
}
for _, rr := range msg.Answer {
if aaaa, ok := rr.(*dns.AAAA); ok {
result = append(result, aaaa.AAAA.String())
}
}
return result, nil
}
func init() {
var err error
clientConfig, err = dns.ClientConfigFromFile("/etc/resolv.conf")
if err != nil {
clientConfig = &dns.ClientConfig{
Servers: []string{"8.8.8.8", "4.2.2.2"},
Port: "53",
Ndots: 1,
Timeout: 2,
Attempts: 3,
}
}
//clientConfig.Port = "53"
dnsClient = &dns.Client{
Net: "udp",
Timeout: time.Duration(clientConfig.Timeout) * time.Second,
UDPSize: 4096,
}
}

@ -0,0 +1,21 @@
package main
import (
"fmt"
//"net"
"testing"
)
func TestLookuphost(t *testing.T) {
for _, h := range []string{"www.ifeng.com", "www.taobao.com",
"www.baidu.com", "www.sina.com.cn", "www.163.com", "www.qq.com",
"www.google.com", "www.facebook.com", "twitter.com",
} {
ret, err := lookupHost(h)
if err != nil {
t.Errorf("%s: %s", h, err)
}
fmt.Printf("%#v\n", ret)
}
}

@ -0,0 +1,9 @@
package main
import (
"net"
)
func lookupHost(host string) ([]string, error) {
return net.LookupHost(host)
}

@ -0,0 +1,147 @@
package main
import (
"flag"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"strings"
"time"
)
var defaultTransport http.RoundTripper = &http.Transport{
DialContext: dialContext,
MaxIdleConns: 50,
IdleConnTimeout: 30 * time.Second,
MaxIdleConnsPerHost: 3,
//ResponseHeaderTimeout: 2 * time.Second,
}
func main() {
var docroot string
var enableProxy bool
var port int
curdir, err := os.Getwd()
if err != nil {
curdir = "."
}
flag.StringVar(&docroot, "docroot", curdir, "document root")
flag.BoolVar(&enableProxy, "enable_proxy", false, "enable proxy function")
flag.IntVar(&port, "port", 8080, "the port listen to")
flag.Parse()
http.Handle("/", http.FileServer(http.Dir(docroot)))
log.Printf("Listen on :%d", port)
log.Printf("document root %s", docroot)
if enableProxy {
log.Println("proxy enabled")
}
err = http.ListenAndServe(fmt.Sprintf(":%d", port), &handler{
enableProxy: enableProxy,
})
if err != nil {
log.Fatal(err)
}
}
type handler struct {
enableProxy bool
}
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.RequestURI[0] == '/' {
http.DefaultServeMux.ServeHTTP(w, r)
return
}
if !h.enableProxy {
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, "<h1>page not found!</h1>")
return
}
if r.Method == http.MethodConnect {
h.handleCONNECT(w, r)
} else {
h.handleHTTP(w, r)
}
}
func (h *handler) handleHTTP(w http.ResponseWriter, r *http.Request) {
var resp *http.Response
var err error
r.Header.Del("proxy-connection")
resp, err = defaultTransport.RoundTrip(r)
if err != nil {
log.Printf("RoundTrip: %s", err)
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusServiceUnavailable)
fmt.Fprintf(w, "%s", err)
return
}
defer resp.Body.Close()
hdr := w.Header()
//resp.Header.Del("connection")
for k, v := range resp.Header {
for _, v1 := range v {
hdr.Add(k, v1)
}
}
w.WriteHeader(resp.StatusCode)
io.Copy(w, resp.Body)
}
func (h *handler) handleCONNECT(w http.ResponseWriter, r *http.Request) {
host := r.RequestURI
if !strings.Contains(host, ":") {
host = fmt.Sprintf("%s:443", host)
}
var conn net.Conn
var err error
conn, err = dial("tcp", host)
if err != nil {
log.Printf("net.dial: %s", err)
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusServiceUnavailable)
fmt.Fprintf(w, "dial to %s failed: %s", host, err)
return
}
hj, _ := w.(http.Hijacker)
conn1, _, _ := hj.Hijack()
fmt.Fprintf(conn1, "%s 200 connection established\r\n\r\n", r.Proto)
pipeAndClose(conn, conn1)
}
func pipeAndClose(r1, r2 io.ReadWriteCloser) {
ch := make(chan int, 2)
go func() {
io.Copy(r1, r2)
ch <- 1
}()
go func() {
io.Copy(r2, r1)
ch <- 1
}()
<-ch
}
Loading…
Cancel
Save