Optionally serve rewrite requests from tcp socket
This commit is contained in:
parent
971fb9efa4
commit
8bb615b741
79
main.go
79
main.go
|
@ -3,7 +3,9 @@ package main
|
|||
import (
|
||||
"bufio"
|
||||
"flag"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -16,19 +18,20 @@ import (
|
|||
)
|
||||
|
||||
var Flags struct {
|
||||
ConfigFile string
|
||||
LogFile string
|
||||
DumpFile string
|
||||
Verbose bool
|
||||
ConfigFile string
|
||||
LogFile string
|
||||
DumpFile string
|
||||
Verbose bool
|
||||
ListenAddress string
|
||||
}
|
||||
|
||||
var ready = false
|
||||
var L *log.Logger
|
||||
|
||||
func RewritesFromConfig(cfgfile string) (*trie.Trie, error) {
|
||||
t := trie.NewTrie()
|
||||
func RewritesFromConfig(t *trie.Trie, cfgfile string) error {
|
||||
cfg, err := config.Load(cfgfile)
|
||||
if err != nil {
|
||||
return &t, err
|
||||
return err
|
||||
}
|
||||
|
||||
repoRewrites := make(map[string]map[string]string, 100)
|
||||
|
@ -128,12 +131,13 @@ func RewritesFromConfig(cfgfile string) (*trie.Trie, error) {
|
|||
}
|
||||
}
|
||||
|
||||
return &t, nil
|
||||
ready = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Rewrite(t *trie.Trie) {
|
||||
L.Println("Listening for requests")
|
||||
scanner := bufio.NewScanner(os.Stdin)
|
||||
func Rewrite(t *trie.Trie, r io.ReadCloser, w io.WriteCloser, closeStreams bool) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if Flags.Verbose {
|
||||
|
@ -144,40 +148,43 @@ func Rewrite(t *trie.Trie) {
|
|||
continue
|
||||
}
|
||||
resp := req.MakeResponse()
|
||||
if req.Url != "" {
|
||||
resp.Result = "OK"
|
||||
resp.Result = "ERR"
|
||||
if req.Url != "" && ready {
|
||||
prefix, dest := t.GetLongestPrefix(req.Url)
|
||||
if sdest, ok := dest.(string); ok {
|
||||
resp.Result = "OK"
|
||||
newUrl := sdest + req.Url[len(prefix):]
|
||||
resp.RewriteTo(newUrl)
|
||||
resp.StoreId(newUrl)
|
||||
if Flags.Verbose {
|
||||
L.Printf("Rewriting %s to %s", req.Url, newUrl)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
resp.Result = "ERR"
|
||||
}
|
||||
os.Stdout.WriteString(resp.Format() + "\n")
|
||||
w.Write([]byte(resp.Format() + "\n"))
|
||||
}
|
||||
if closeStreams {
|
||||
r.Close()
|
||||
w.Close()
|
||||
}
|
||||
L.Println("End of input stream. Exiting.")
|
||||
}
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&Flags.ConfigFile, "c", "", "Path to rewrite config file")
|
||||
flag.StringVar(&Flags.LogFile, "l", "", "Path to log file")
|
||||
flag.StringVar(&Flags.LogFile, "log", "", "Path to log file")
|
||||
flag.StringVar(&Flags.DumpFile, "dump", "", "Path to dump file")
|
||||
flag.BoolVar(&Flags.Verbose, "v", false, "Verbose logging")
|
||||
flag.StringVar(&Flags.ListenAddress, "listen", "", "Listening address")
|
||||
}
|
||||
|
||||
var RewriteFileNames []string = []string{"rewrites.yaml", "rewrites.yml", "/etc/squid/rewrites.yaml", "/etc/squid/rewrites.yml"}
|
||||
|
||||
func LoadRewrites() *trie.Trie {
|
||||
var t *trie.Trie
|
||||
func LoadRewrites(t *trie.Trie) {
|
||||
var err error
|
||||
|
||||
start := time.Now()
|
||||
if Flags.ConfigFile != "" {
|
||||
t, err = RewritesFromConfig(Flags.ConfigFile)
|
||||
err = RewritesFromConfig(t, Flags.ConfigFile)
|
||||
if err != nil {
|
||||
L.Fatalf("Can't load the specified rewrite file: %s", Flags.ConfigFile)
|
||||
}
|
||||
|
@ -186,7 +193,7 @@ func LoadRewrites() *trie.Trie {
|
|||
if Flags.Verbose {
|
||||
L.Printf("Trying to load rewrites from %s", filename)
|
||||
}
|
||||
t, err = RewritesFromConfig(filename)
|
||||
err = RewritesFromConfig(t, filename)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
@ -207,8 +214,6 @@ func LoadRewrites() *trie.Trie {
|
|||
|
||||
t.Dump(file)
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
@ -224,5 +229,29 @@ func main() {
|
|||
} else {
|
||||
L = log.Default()
|
||||
}
|
||||
Rewrite(LoadRewrites())
|
||||
t := trie.NewTrie()
|
||||
go LoadRewrites(&t)
|
||||
if Flags.ListenAddress == "" {
|
||||
L.Println("Accepting requests from stdin")
|
||||
Rewrite(&t, os.Stdin, os.Stdout, false)
|
||||
L.Println("End of input stream. Exiting.")
|
||||
} else {
|
||||
listen, err := net.Listen("tcp", Flags.ListenAddress)
|
||||
if err != nil {
|
||||
L.Fatal(err)
|
||||
}
|
||||
defer listen.Close()
|
||||
for {
|
||||
conn, err := listen.Accept()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
go func() {
|
||||
client := conn.RemoteAddr().String()
|
||||
L.Printf("Accepting requests from client %s", client)
|
||||
Rewrite(&t, conn, conn, true)
|
||||
L.Printf("Client %s disconnected", client)
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,3 +30,7 @@ func (r *Response) SetArg(arg string, value string) {
|
|||
func (r *Response) RewriteTo(dest string) {
|
||||
r.SetArg("rewrite-url", dest)
|
||||
}
|
||||
|
||||
func (r *Response) StoreId(id string) {
|
||||
r.SetArg("store-id", id)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue