diff --git a/customHeaders.go b/customHeaders.go new file mode 100644 index 0000000..0056021 --- /dev/null +++ b/customHeaders.go @@ -0,0 +1,100 @@ +package main + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "os" + "path/filepath" + "strings" +) + +// HeaderConfigArray is the array which contains all the custom header rules +type HeaderConfigArray struct { + Configs []HeaderConfig `json:"configs"` +} + +// HeaderConfig is a single header rule specification +type HeaderConfig struct { + Path string `json:"path"` + FileExtension string `json:"fileExtension"` + Headers []HeaderDefiniton `json:"headers"` +} + +// HeaderDefiniton is a key value pair of a specified header rule +type HeaderDefiniton struct { + Key string `json:"key"` + Value string `json:"value"` +} + +var headerConfigs HeaderConfigArray + +func fileExists(filename string) bool { + info, err := os.Stat(filename) + if os.IsNotExist(err) { + return false + } + return !info.IsDir() +} + +func initHeaderConfig() bool { + headerConfigValid := false + + if fileExists("/config/headerConfig.json") { + jsonFile, err := os.Open("/config/headerConfig.json") + if err != nil { + fmt.Println("Cant't read header config file. Error:") + fmt.Println(err) + } else { + byteValue, _ := ioutil.ReadAll(jsonFile) + + json.Unmarshal(byteValue, &headerConfigs) + + if len(headerConfigs.Configs) > 0 { + headerConfigValid = true + fmt.Println("Found header config file. Rules:") + + for i := 0; i < len(headerConfigs.Configs); i++ { + configEntry := headerConfigs.Configs[i] + fmt.Println("==============================================================") + fmt.Println("Path: " + configEntry.Path) + fmt.Println("FileExtension: " + configEntry.FileExtension) + for j := 0; j < len(configEntry.Headers); j++ { + headerRule := configEntry.Headers[j] + fmt.Println(headerRule.Key, ":", headerRule.Value) + } + fmt.Println("==============================================================") + } + } else { + fmt.Println("No rules found in header config file.") + } + + } + jsonFile.Close() + } + + return headerConfigValid +} + +func customHeadersMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqFileExtension := filepath.Ext(r.URL.Path) + + for i := 0; i < len(headerConfigs.Configs); i++ { + configEntry := headerConfigs.Configs[i] + + fileMatch := configEntry.FileExtension == "*" || reqFileExtension == "."+configEntry.FileExtension + pathMatch := configEntry.Path == "*" || strings.HasPrefix(r.URL.Path, configEntry.Path) + + if fileMatch && pathMatch { + for j := 0; j < len(configEntry.Headers); j++ { + headerEntry := configEntry.Headers[j] + w.Header().Set(headerEntry.Key, headerEntry.Value) + } + } + } + + next.ServeHTTP(w, r) + }) +} diff --git a/main.go b/main.go index 40480b9..b2617e6 100644 --- a/main.go +++ b/main.go @@ -14,9 +14,6 @@ import ( "strconv" "strings" "sync" - "encoding/json" - "os" - "path/filepath" ) var ( @@ -70,15 +67,8 @@ func (w *gzipResponseWriter) Write(b []byte) (int, error) { return w.Writer.Write(b) } -func handleReq(h http.Handler, headerConfigs HeaderConfigs) http.Handler { +func handleReq(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - /* - log.Println("======================") - log.Println("path =>", r.URL.Path) - log.Println("extension =>", filepath.Ext(r.URL.Path)) - log.Println("pagedata =>", strings.HasPrefix(r.URL.Path, "/page-data")) - log.Println("======================") - */ if *httpsPromote && r.Header.Get("X-Forwarded-Proto") == "http" { http.Redirect(w, r, "https://"+r.Host+r.RequestURI, http.StatusMovedPermanently) if *logRequest { @@ -91,91 +81,14 @@ func handleReq(h http.Handler, headerConfigs HeaderConfigs) http.Handler { log.Println(r.Method, r.URL.Path) } - // apply custom header config - if (len(headerConfigs.Configs) > 0) { - reqFileExtension := filepath.Ext(r.URL.Path) - - for i := 0; i < len(headerConfigs.Configs); i++ { - configEntry := headerConfigs.Configs[i] - - fileMatch := configEntry.FileExtension == "*" || reqFileExtension == "."+ configEntry.FileExtension - pathMatch := configEntry.Path == "*" || strings.HasPrefix(r.URL.Path, configEntry.Path) - - // log.Println("fileMatch", reqFileExtension, configEntry.FileExtension) - // log.Println("pathMatch", r.URL.Path, configEntry.Path) - - if fileMatch && pathMatch { - log.Println(r.URL.Path, configEntry.FileExtension, configEntry.Path) - for j := 0; j < len(configEntry.Headers); j++ { - headerEntry := configEntry.Headers[j] - - log.Println("add header", headerEntry.Key, headerEntry.Value, "to", r.URL.Path) - w.Header().Set(headerEntry.Key, headerEntry.Value) - } - } - } - } - h.ServeHTTP(w, r) }) } -type HeaderConfigs struct { - Configs []HConfig `json:"configs"` -} - -type HConfig struct { - Path string `json:"path"` - FileExtension string `json:"fileExtension"` - Headers []Header `json:"headers"` -} - -type Header struct { - Key string `json:"key"` - Value string `json:"value"` -} - - -func initHeaderConfig() HeaderConfigs { - if fileExists("/config/headerConfig.json") { - jsonFile, err := os.Open("/config/headerConfig.json") - if err != nil { - fmt.Println(err) - } - - byteValue, _ := ioutil.ReadAll(jsonFile) - - var headerConfigs HeaderConfigs - - json.Unmarshal(byteValue, &headerConfigs) - - fmt.Println("Found header config file. Rules:") - - for i := 0; i < len(headerConfigs.Configs); i++ { - fmt.Println("==============================================================") - fmt.Println("Path: " + headerConfigs.Configs[i].Path) - fmt.Println("FileExtension: " + headerConfigs.Configs[i].FileExtension) - for j := 0; j < len(headerConfigs.Configs); j++ { - headerRule := headerConfigs.Configs[i].Headers[j] - fmt.Println(headerRule.Key, ":", headerRule.Value) - } - fmt.Println("==============================================================") - } - - jsonFile.Close() - - return headerConfigs - } else { - return nil - } -} - func main() { flag.Parse() - headerConfigs := initHeaderConfig() - // sanity check if len(*setBasicAuth) != 0 && !*basicAuth { *basicAuth = true @@ -192,7 +105,7 @@ func main() { } } - handler := handleReq(http.FileServer(fileSystem), headerConfigs) + handler := handleReq(http.FileServer(fileSystem)) pathPrefix := "/" if len(*context) > 0 { @@ -210,6 +123,11 @@ func main() { handler = authMiddleware(handler) } + headerConfigValid := initHeaderConfig() + if headerConfigValid { + handler = customHeadersMiddleware(handler) + } + // Extra headers. if len(*headerFlag) > 0 { header, headerValue := parseHeaderFlag(*headerFlag)