258 lines
5.6 KiB
Go
258 lines
5.6 KiB
Go
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"flag"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"git.worn.eu/guru/squid-rewriter/config"
|
|
"git.worn.eu/guru/squid-rewriter/distro"
|
|
"git.worn.eu/guru/squid-rewriter/squid"
|
|
"git.worn.eu/guru/squid-rewriter/trie"
|
|
)
|
|
|
|
// Flags holds the command-line options. The fields are bound to flags in
// init and filled in when main calls flag.Parse.
var Flags struct {
	ConfigFile    string // -c: rewrite config file to load instead of RewriteFileNames
	LogFile       string // -log: append log output to this file
	DumpFile      string // -dump: write the loaded trie to this file
	Verbose       bool   // -v: verbose logging
	ListenAddress string // -listen: serve over TCP on this address instead of stdin/stdout
}

var (
	// ready is set to true by RewritesFromConfig once the rules are loaded;
	// until then Rewrite answers every request with "ERR".
	// NOTE(review): written and read from different goroutines without
	// synchronization; an atomic flag would remove the data race.
	ready bool

	// L is the program-wide logger; it is nil until main assigns it.
	L *log.Logger
)
|
|
|
|
func RewritesFromConfig(t *trie.Trie, cfgfile string) error {
|
|
cfg, err := config.Load(cfgfile)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
repoRewrites := make(map[string]map[string]string, 100)
|
|
|
|
for _, r := range cfg.Rewrites {
|
|
if r.Destination == "" {
|
|
continue
|
|
}
|
|
if r.Distro != "" {
|
|
distroName := r.Distro
|
|
repoName := "main"
|
|
if strings.Contains(distroName, ":") {
|
|
f := strings.SplitN(distroName, ":", 2)
|
|
distroName = f[0]
|
|
repoName = f[1]
|
|
}
|
|
if _, ok := repoRewrites[distroName]; !ok {
|
|
repoRewrites[distroName] = make(map[string]string)
|
|
}
|
|
repoRewrites[distroName][repoName] = r.Destination
|
|
}
|
|
}
|
|
|
|
if len(repoRewrites) > 0 {
|
|
repoMirrors := make(chan distro.RepoMirror)
|
|
wg := sync.WaitGroup{}
|
|
|
|
go func() {
|
|
for r := range repoMirrors {
|
|
d, ok := repoRewrites[r.Distro]
|
|
if !ok {
|
|
continue
|
|
}
|
|
destination, ok := d[r.Repo]
|
|
if !ok {
|
|
continue
|
|
}
|
|
t.Put(r.Url, destination)
|
|
}
|
|
}()
|
|
|
|
for distroName, repos := range repoRewrites {
|
|
repoList := []string{}
|
|
for r := range repos {
|
|
repoList = append(repoList, r)
|
|
}
|
|
d := distro.GetDistro(distroName)
|
|
if d == nil {
|
|
continue
|
|
}
|
|
L.Printf("Fetching %s mirrors %v", distroName, repoList)
|
|
wg.Add(1)
|
|
go d.FetchMirrors(repoList, repoMirrors, &wg)
|
|
}
|
|
|
|
wg.Wait()
|
|
close(repoMirrors)
|
|
if len(repoMirrors) > 0 {
|
|
L.Println("Waiting for all repo rewrite rules to be processed...")
|
|
for len(repoMirrors) > 0 {
|
|
time.Sleep(5 * time.Millisecond)
|
|
}
|
|
}
|
|
}
|
|
|
|
for _, r := range cfg.Rewrites {
|
|
if r.Destination == "" {
|
|
continue
|
|
}
|
|
if len(r.Urls) > 0 {
|
|
if Flags.Verbose {
|
|
L.Printf("Loading static rewites for %s", r.Name)
|
|
}
|
|
for _, url := range r.Urls {
|
|
t.Put(url, r.Destination)
|
|
}
|
|
}
|
|
if r.Filename != "" {
|
|
if Flags.Verbose {
|
|
L.Printf("Loading rewrites from %s for %s", r.Filename, r.Name)
|
|
}
|
|
f, err := os.Open(r.Filename)
|
|
if err == nil {
|
|
defer f.Close()
|
|
scanner := bufio.NewScanner(f)
|
|
for scanner.Scan() {
|
|
line := strings.TrimSpace(scanner.Text())
|
|
if line != "" && !strings.HasPrefix(line, "#") {
|
|
if line != r.Destination {
|
|
t.Put(line, r.Destination)
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
L.Printf("Can't open file %s", r.Filename)
|
|
}
|
|
}
|
|
}
|
|
|
|
ready = true
|
|
|
|
return nil
|
|
}
|
|
|
|
func Rewrite(t *trie.Trie, r io.ReadCloser, w io.WriteCloser, closeStreams bool) {
|
|
scanner := bufio.NewScanner(r)
|
|
for scanner.Scan() {
|
|
line := strings.TrimSpace(scanner.Text())
|
|
if Flags.Verbose {
|
|
L.Println("Parsing request:", line)
|
|
}
|
|
req := squid.ParseRequest(line)
|
|
if req == nil {
|
|
continue
|
|
}
|
|
resp := req.MakeResponse()
|
|
resp.Result = "ERR"
|
|
if req.Url != "" && ready {
|
|
prefix, dest := t.GetLongestPrefix(req.Url)
|
|
if sdest, ok := dest.(string); ok {
|
|
resp.Result = "OK"
|
|
newUrl := sdest + req.Url[len(prefix):]
|
|
resp.RewriteTo(newUrl)
|
|
resp.StoreId(newUrl)
|
|
if Flags.Verbose {
|
|
L.Printf("Rewriting %s to %s", req.Url, newUrl)
|
|
}
|
|
}
|
|
}
|
|
w.Write([]byte(resp.Format() + "\n"))
|
|
}
|
|
if closeStreams {
|
|
r.Close()
|
|
w.Close()
|
|
}
|
|
}
|
|
|
|
func init() {
|
|
flag.StringVar(&Flags.ConfigFile, "c", "", "Path to rewrite config file")
|
|
flag.StringVar(&Flags.LogFile, "log", "", "Path to log file")
|
|
flag.StringVar(&Flags.DumpFile, "dump", "", "Path to dump file")
|
|
flag.BoolVar(&Flags.Verbose, "v", false, "Verbose logging")
|
|
flag.StringVar(&Flags.ListenAddress, "listen", "", "Listening address")
|
|
}
|
|
|
|
// RewriteFileNames lists, in order, the config files LoadRewrites tries when
// no -c flag is given: first relative to the working directory, then under
// /etc/squid.
var RewriteFileNames = []string{
	"rewrites.yaml",
	"rewrites.yml",
	"/etc/squid/rewrites.yaml",
	"/etc/squid/rewrites.yml",
}
|
|
|
|
func LoadRewrites(t *trie.Trie) {
|
|
var err error
|
|
|
|
start := time.Now()
|
|
if Flags.ConfigFile != "" {
|
|
err = RewritesFromConfig(t, Flags.ConfigFile)
|
|
if err != nil {
|
|
L.Fatalf("Can't load the specified rewrite file: %s", Flags.ConfigFile)
|
|
}
|
|
} else {
|
|
for _, filename := range RewriteFileNames {
|
|
if Flags.Verbose {
|
|
L.Printf("Trying to load rewrites from %s", filename)
|
|
}
|
|
err = RewritesFromConfig(t, filename)
|
|
if err == nil {
|
|
break
|
|
}
|
|
}
|
|
if err != nil {
|
|
L.Fatalf("Can't load any of the predefined rewrite files: %v", RewriteFileNames)
|
|
}
|
|
}
|
|
elapsed := time.Since(start)
|
|
L.Printf("Loaded %d unique rewrites in %s", t.Count(), elapsed)
|
|
|
|
if Flags.DumpFile != "" {
|
|
file, err := os.OpenFile(Flags.DumpFile, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0644)
|
|
if err != nil {
|
|
L.Fatal(err)
|
|
}
|
|
defer file.Close()
|
|
|
|
t.Dump(file)
|
|
}
|
|
}
|
|
|
|
func main() {
|
|
flag.Parse()
|
|
if Flags.LogFile != "" {
|
|
file, err := os.OpenFile(Flags.LogFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
|
if err != nil {
|
|
L.Fatal(err)
|
|
}
|
|
defer file.Close()
|
|
|
|
L = log.New(file, "", log.LstdFlags)
|
|
} else {
|
|
L = log.Default()
|
|
}
|
|
t := trie.NewTrie()
|
|
go LoadRewrites(&t)
|
|
if Flags.ListenAddress == "" {
|
|
L.Println("Accepting requests from stdin")
|
|
Rewrite(&t, os.Stdin, os.Stdout, false)
|
|
L.Println("End of input stream. Exiting.")
|
|
} else {
|
|
listen, err := net.Listen("tcp", Flags.ListenAddress)
|
|
if err != nil {
|
|
L.Fatal(err)
|
|
}
|
|
defer listen.Close()
|
|
for {
|
|
conn, err := listen.Accept()
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
go func() {
|
|
client := conn.RemoteAddr().String()
|
|
L.Printf("Accepting requests from client %s", client)
|
|
Rewrite(&t, conn, conn, true)
|
|
L.Printf("Client %s disconnected", client)
|
|
}()
|
|
}
|
|
}
|
|
}
|