diff --git a/main.go b/main.go index 8a712cf..a923aaf 100644 --- a/main.go +++ b/main.go @@ -3,7 +3,9 @@ package main import ( "bufio" "flag" + "io" "log" + "net" "os" "strings" "sync" @@ -16,19 +18,20 @@ import ( ) var Flags struct { - ConfigFile string - LogFile string - DumpFile string - Verbose bool + ConfigFile string + LogFile string + DumpFile string + Verbose bool + ListenAddress string } +var ready = false var L *log.Logger -func RewritesFromConfig(cfgfile string) (*trie.Trie, error) { - t := trie.NewTrie() +func RewritesFromConfig(t *trie.Trie, cfgfile string) error { cfg, err := config.Load(cfgfile) if err != nil { - return &t, err + return err } repoRewrites := make(map[string]map[string]string, 100) @@ -128,12 +131,13 @@ func RewritesFromConfig(cfgfile string) (*trie.Trie, error) { } } - return &t, nil + ready = true + + return nil } -func Rewrite(t *trie.Trie) { - L.Println("Listening for requests") - scanner := bufio.NewScanner(os.Stdin) +func Rewrite(t *trie.Trie, r io.ReadCloser, w io.WriteCloser, closeStreams bool) { + scanner := bufio.NewScanner(r) for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) if Flags.Verbose { @@ -144,40 +148,43 @@ func Rewrite(t *trie.Trie) { continue } resp := req.MakeResponse() - if req.Url != "" { - resp.Result = "OK" + resp.Result = "ERR" + if req.Url != "" && ready { prefix, dest := t.GetLongestPrefix(req.Url) if sdest, ok := dest.(string); ok { + resp.Result = "OK" newUrl := sdest + req.Url[len(prefix):] resp.RewriteTo(newUrl) + resp.StoreId(newUrl) if Flags.Verbose { L.Printf("Rewriting %s to %s", req.Url, newUrl) } } - } else { - resp.Result = "ERR" } - os.Stdout.WriteString(resp.Format() + "\n") + w.Write([]byte(resp.Format() + "\n")) + } + if closeStreams { + r.Close() + w.Close() } - L.Println("End of input stream. Exiting.") } func init() { flag.StringVar(&Flags.ConfigFile, "c", "", "Path to rewrite config file") - flag.StringVar(&Flags.LogFile, "l", "", "Path to log file") + flag.StringVar(&Flags.LogFile, "log", "", "Path to log file") flag.StringVar(&Flags.DumpFile, "dump", "", "Path to dump file") flag.BoolVar(&Flags.Verbose, "v", false, "Verbose logging") + flag.StringVar(&Flags.ListenAddress, "listen", "", "Listening address") } var RewriteFileNames []string = []string{"rewrites.yaml", "rewrites.yml", "/etc/squid/rewrites.yaml", "/etc/squid/rewrites.yml"} -func LoadRewrites() *trie.Trie { - var t *trie.Trie +func LoadRewrites(t *trie.Trie) { var err error start := time.Now() if Flags.ConfigFile != "" { - t, err = RewritesFromConfig(Flags.ConfigFile) + err = RewritesFromConfig(t, Flags.ConfigFile) if err != nil { L.Fatalf("Can't load the specified rewrite file: %s", Flags.ConfigFile) } @@ -186,7 +193,7 @@ func LoadRewrites() *trie.Trie { if Flags.Verbose { L.Printf("Trying to load rewrites from %s", filename) } - t, err = RewritesFromConfig(filename) + err = RewritesFromConfig(t, filename) if err == nil { break } @@ -207,8 +214,6 @@ func LoadRewrites() *trie.Trie { t.Dump(file) } - - return t } func main() { @@ -224,5 +229,29 @@ func main() { } else { L = log.Default() } - Rewrite(LoadRewrites()) + t := trie.NewTrie() + go LoadRewrites(&t) + if Flags.ListenAddress == "" { + L.Println("Accepting requests from stdin") + Rewrite(&t, os.Stdin, os.Stdout, false) + L.Println("End of input stream. Exiting.") + } else { + listen, err := net.Listen("tcp", Flags.ListenAddress) + if err != nil { + L.Fatal(err) + } + defer listen.Close() + for { + conn, err := listen.Accept() + if err != nil { + log.Fatal(err) + } + go func() { + client := conn.RemoteAddr().String() + L.Printf("Accepting requests from client %s", client) + Rewrite(&t, conn, conn, true) + L.Printf("Client %s disconnected", client) + }() + } + } } diff --git a/squid/response.go b/squid/response.go index 32e1d80..d30a2bb 100644 --- a/squid/response.go +++ b/squid/response.go @@ -30,3 +30,7 @@ func (r *Response) SetArg(arg string, value string) { func (r *Response) RewriteTo(dest string) { r.SetArg("rewrite-url", dest) } + +func (r *Response) StoreId(id string) { + r.SetArg("store-id", id) +}