1
0
Fork 0
squid-rewriter/main.go

198 lines
4.2 KiB
Go

package main
import (
"bufio"
"flag"
"log"
"os"
"strings"
"sync"
"time"
"git.worn.eu/guru/squid-rewriter/config"
"git.worn.eu/guru/squid-rewriter/distro"
"git.worn.eu/guru/squid-rewriter/squid"
"git.worn.eu/guru/squid-rewriter/trie"
)
var (
	// Flags holds the command-line options; init registers them with the
	// flag package and main fills them in via flag.Parse.
	Flags struct {
		ConfigFile string // -c: explicit path to the rewrite config file
		Verbose    bool   // -v: enable verbose logging
	}
)
// RewritesFromConfig loads the rewrite config at cfgfile and builds a prefix
// trie mapping URL prefixes to rewrite destinations. Only rules with a
// non-empty Destination are used, and entries come from three sources:
//   - Distro ("name" or "name:repo", repo defaulting to "main"): mirror URLs
//     fetched concurrently through the distro package;
//   - Urls: static prefixes listed directly in the config;
//   - Filename: a text file with one prefix per line; blank lines, lines
//     starting with "#", and lines equal to the destination are skipped.
//
// If the config itself cannot be loaded, the empty trie is returned together
// with the error. Unreadable rewrite files are logged and skipped.
func RewritesFromConfig(cfgfile string) (*trie.Trie, error) {
	t := trie.NewTrie()
	cfg, err := config.Load(cfgfile)
	if err != nil {
		return &t, err
	}

	// distro name -> repo name -> destination
	repoRewrites := make(map[string]map[string]string, 100)
	for _, r := range cfg.Rewrites {
		if r.Destination == "" || r.Distro == "" {
			continue
		}
		distroName, repoName := splitDistroSpec(r.Distro)
		repos, ok := repoRewrites[distroName]
		if !ok {
			repos = make(map[string]string)
			repoRewrites[distroName] = repos
		}
		repos[repoName] = r.Destination
	}
	if len(repoRewrites) > 0 {
		loadRepoRewrites(&t, repoRewrites)
	}

	for _, r := range cfg.Rewrites {
		if r.Destination == "" {
			continue
		}
		if len(r.Urls) > 0 {
			if Flags.Verbose {
				log.Printf("Loading static rewrites for %s", r.Name)
			}
			for _, url := range r.Urls {
				t.Put(url, r.Destination)
			}
		}
		if r.Filename != "" {
			if Flags.Verbose {
				log.Printf("Loading rewrites from %s for %s", r.Filename, r.Name)
			}
			loadRewriteFile(&t, r.Filename, r.Destination)
		}
	}
	return &t, nil
}

// splitDistroSpec splits a "distro" or "distro:repo" spec at the first colon;
// the repo defaults to "main" when no colon is present.
func splitDistroSpec(spec string) (distroName, repoName string) {
	if i := strings.Index(spec, ":"); i >= 0 {
		return spec[:i], spec[i+1:]
	}
	return spec, "main"
}

// loadRepoRewrites fetches mirror lists for every distro in repoRewrites and
// inserts each mirror URL whose repo has a configured destination into t.
// It returns only after every fetched mirror has been inserted.
func loadRepoRewrites(t *trie.Trie, repoRewrites map[string]map[string]string) {
	mirrors := make(chan distro.RepoMirror)
	done := make(chan struct{})

	// Single consumer: while fetchers run, the trie is written only from here.
	go func() {
		defer close(done)
		for m := range mirrors {
			// Indexing a missing distro yields a nil map, so ok is false.
			if dest, ok := repoRewrites[m.Distro][m.Repo]; ok {
				t.Put(m.Url, dest)
			}
		}
	}()

	var wg sync.WaitGroup
	for distroName, repos := range repoRewrites {
		d := distro.GetDistro(distroName)
		if d == nil {
			log.Printf("Unknown distro %s, skipping its repo rewrites", distroName)
			continue
		}
		repoList := make([]string, 0, len(repos))
		for r := range repos {
			repoList = append(repoList, r)
		}
		log.Printf("Fetching %s mirrors %v", distroName, repoList)
		wg.Add(1)
		// NOTE(review): FetchMirrors is presumably responsible for wg.Done().
		go d.FetchMirrors(repoList, mirrors, &wg)
	}
	wg.Wait()
	close(mirrors)
	// Closing the channel does not mean the consumer has finished its last
	// t.Put; wait for it so the caller can safely keep writing to t.
	<-done
}

// loadRewriteFile adds one rewrite per meaningful line of filename to t, all
// pointing at destination. Open and read failures are logged, not returned.
func loadRewriteFile(t *trie.Trie, filename, destination string) {
	f, err := os.Open(filename)
	if err != nil {
		log.Printf("Can't open file %s: %v", filename, err)
		return
	}
	// Deferred here (not in the caller's loop) so each file closes promptly.
	defer f.Close()

	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" || strings.HasPrefix(line, "#") || line == destination {
			continue
		}
		t.Put(line, destination)
	}
	if err := scanner.Err(); err != nil {
		log.Printf("Error reading %s: %v", filename, err)
	}
}
// Rewrite runs the squid URL-rewriter helper loop. It reads one request per
// line from stdin and looks up the longest prefix of the request URL in t.
// When that prefix maps to a string destination, it answers with the URL whose
// matched prefix is replaced by the destination. Requests with an empty URL
// get an "ERR" result; all other parsed requests get "OK". Each response is
// written to stdout as one line. It returns when stdin is exhausted, a read
// fails, or a response cannot be written.
func Rewrite(t *trie.Trie) {
	log.Println("Listening for requests")
	scanner := bufio.NewScanner(os.Stdin)
	for scanner.Scan() {
		req := squid.ParseRequest(scanner.Text())
		if req == nil {
			// ParseRequest rejected the line (presumably malformed); no
			// response is sent for it.
			continue
		}
		resp := req.MakeResponse()
		if req.Url == "" {
			resp.Result = "ERR"
		} else {
			resp.Result = "OK"
			prefix, dest := t.GetLongestPrefix(req.Url)
			if sdest, ok := dest.(string); ok {
				newURL := sdest + req.Url[len(prefix):]
				resp.RewriteTo(newURL)
				if Flags.Verbose {
					log.Printf("Rewriting %s to %s", req.Url, newURL)
				}
			}
		}
		// A failed write means squid can no longer read our answers; stop
		// instead of silently consuming further requests.
		if _, err := os.Stdout.WriteString(resp.Format() + "\n"); err != nil {
			log.Printf("Can't write response: %v", err)
			return
		}
	}
	// Scan also stops on read errors (e.g. an over-long line); don't report
	// those as a normal end of input.
	if err := scanner.Err(); err != nil {
		log.Printf("Error reading requests: %v", err)
		return
	}
	log.Println("End of input stream. Exiting.")
}
// init registers the -v and -c command-line flags, binding them to the
// fields of Flags so that flag.Parse in main populates them.
func init() {
	flag.BoolVar(&Flags.Verbose, "v", false, "Verbose logging")
	flag.StringVar(&Flags.ConfigFile, "c", "", "Path to rewrite config file")
}
// RewriteFileNames lists, in lookup order, the config files LoadRewrites
// tries when no -c flag is given: first relative to the working directory,
// then under /etc/squid.
var RewriteFileNames = []string{
	"rewrites.yaml",
	"rewrites.yml",
	"/etc/squid/rewrites.yaml",
	"/etc/squid/rewrites.yml",
}
// LoadRewrites builds the rewrite trie. It reads Flags.ConfigFile if it was
// given; otherwise it uses the first entry of RewriteFileNames that
// RewritesFromConfig loads without error. It logs the number of rewrites and
// the elapsed time. It terminates the process via log.Fatalf if no config can
// be loaded.
func LoadRewrites() *trie.Trie {
	var t *trie.Trie
	var err error
	start := time.Now()
	if Flags.ConfigFile != "" {
		t, err = RewritesFromConfig(Flags.ConfigFile)
		if err != nil {
			log.Fatalf("Can't load the specified rewrite file %s: %v", Flags.ConfigFile, err)
		}
	} else {
		for _, filename := range RewriteFileNames {
			if Flags.Verbose {
				log.Printf("Trying to load rewrites from %s", filename)
			}
			t, err = RewritesFromConfig(filename)
			if err == nil {
				break
			}
			if Flags.Verbose {
				log.Printf("Can't load %s: %v", filename, err)
			}
		}
		// t is still nil when RewriteFileNames is empty, even though err is nil.
		if err != nil || t == nil {
			log.Fatalf("Can't load any of the predefined rewrite files %v (last error: %v)", RewriteFileNames, err)
		}
	}
	elapsed := time.Since(start)
	log.Printf("Loaded %d unique rewrites in %s", t.Count(), elapsed)
	return t
}
func main() {
flag.Parse()
Rewrite(LoadRewrites())
}