package main import ( "bufio" "flag" "io" "log" "net" "os" "strings" "sync" "time" "git.worn.eu/guru/squid-rewriter/config" "git.worn.eu/guru/squid-rewriter/distro" "git.worn.eu/guru/squid-rewriter/squid" "git.worn.eu/guru/squid-rewriter/trie" ) var Flags struct { ConfigFile string LogFile string DumpFile string Verbose bool ListenAddress string } var ready = false var L *log.Logger func RewritesFromConfig(t *trie.Trie, cfgfile string) error { cfg, err := config.Load(cfgfile) if err != nil { return err } repoRewrites := make(map[string]map[string]string, 100) for _, r := range cfg.Rewrites { if r.Destination == "" { continue } if r.Distro != "" { distroName := r.Distro repoName := "main" if strings.Contains(distroName, ":") { f := strings.SplitN(distroName, ":", 2) distroName = f[0] repoName = f[1] } if _, ok := repoRewrites[distroName]; !ok { repoRewrites[distroName] = make(map[string]string) } repoRewrites[distroName][repoName] = r.Destination } } if len(repoRewrites) > 0 { repoMirrors := make(chan distro.RepoMirror) wg := sync.WaitGroup{} go func() { for r := range repoMirrors { d, ok := repoRewrites[r.Distro] if !ok { continue } destination, ok := d[r.Repo] if !ok { continue } t.Put(r.Url, destination) } }() for distroName, repos := range repoRewrites { repoList := []string{} for r := range repos { repoList = append(repoList, r) } d := distro.GetDistro(distroName) if d == nil { continue } L.Printf("Fetching %s mirrors %v", distroName, repoList) wg.Add(1) go d.FetchMirrors(repoList, repoMirrors, &wg) } wg.Wait() close(repoMirrors) if len(repoMirrors) > 0 { L.Println("Waiting for all repo rewrite rules to be processed...") for len(repoMirrors) > 0 { time.Sleep(5 * time.Millisecond) } } } for _, r := range cfg.Rewrites { if r.Destination == "" { continue } if len(r.Urls) > 0 { if Flags.Verbose { L.Printf("Loading static rewites for %s", r.Name) } for _, url := range r.Urls { t.Put(url, r.Destination) } } if r.Filename != "" { if Flags.Verbose { L.Printf("Loading rewrites from %s for %s", r.Filename, r.Name) } f, err := os.Open(r.Filename) if err == nil { defer f.Close() scanner := bufio.NewScanner(f) for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) if line != "" && !strings.HasPrefix(line, "#") { if line != r.Destination { t.Put(line, r.Destination) } } } } else { L.Printf("Can't open file %s", r.Filename) } } } ready = true return nil } func Rewrite(t *trie.Trie, r io.ReadCloser, w io.WriteCloser, closeStreams bool) { scanner := bufio.NewScanner(r) for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) if Flags.Verbose { L.Println("Parsing request:", line) } req := squid.ParseRequest(line) if req == nil { continue } resp := req.MakeResponse() resp.Result = "ERR" if req.Url != "" && ready { prefix, dest := t.GetLongestPrefix(req.Url) if sdest, ok := dest.(string); ok { resp.Result = "OK" newUrl := sdest + req.Url[len(prefix):] resp.RewriteTo(newUrl) resp.StoreId(newUrl) if Flags.Verbose { L.Printf("Rewriting %s to %s", req.Url, newUrl) } } } w.Write([]byte(resp.Format() + "\n")) } if closeStreams { r.Close() w.Close() } } func init() { flag.StringVar(&Flags.ConfigFile, "c", "", "Path to rewrite config file") flag.StringVar(&Flags.LogFile, "log", "", "Path to log file") flag.StringVar(&Flags.DumpFile, "dump", "", "Path to dump file") flag.BoolVar(&Flags.Verbose, "v", false, "Verbose logging") flag.StringVar(&Flags.ListenAddress, "listen", "", "Listening address") } var RewriteFileNames []string = []string{"rewrites.yaml", "rewrites.yml", "/etc/squid/rewrites.yaml", "/etc/squid/rewrites.yml"} func LoadRewrites(t *trie.Trie) { var err error start := time.Now() if Flags.ConfigFile != "" { err = RewritesFromConfig(t, Flags.ConfigFile) if err != nil { L.Fatalf("Can't load the specified rewrite file: %s", Flags.ConfigFile) } } else { for _, filename := range RewriteFileNames { if Flags.Verbose { L.Printf("Trying to load rewrites from %s", filename) } err = RewritesFromConfig(t, filename) if err == nil { break } } if err != nil { L.Fatalf("Can't load any of the predefined rewrite files: %v", RewriteFileNames) } } elapsed := time.Since(start) L.Printf("Loaded %d unique rewrites in %s", t.Count(), elapsed) if Flags.DumpFile != "" { file, err := os.OpenFile(Flags.DumpFile, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { L.Fatal(err) } defer file.Close() t.Dump(file) } } func main() { flag.Parse() if Flags.LogFile != "" { file, err := os.OpenFile(Flags.LogFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { L.Fatal(err) } defer file.Close() L = log.New(file, "", log.LstdFlags) } else { L = log.Default() } t := trie.NewTrie() go LoadRewrites(&t) if Flags.ListenAddress == "" { L.Println("Accepting requests from stdin") Rewrite(&t, os.Stdin, os.Stdout, false) L.Println("End of input stream. Exiting.") } else { listen, err := net.Listen("tcp", Flags.ListenAddress) if err != nil { L.Fatal(err) } defer listen.Close() for { conn, err := listen.Accept() if err != nil { log.Fatal(err) } go func() { client := conn.RemoteAddr().String() L.Printf("Accepting requests from client %s", client) Rewrite(&t, conn, conn, true) L.Printf("Client %s disconnected", client) }() } } }