diff --git a/config/config.go b/config/config.go index 2e084345..3cb2ef29 100644 --- a/config/config.go +++ b/config/config.go @@ -8,6 +8,7 @@ import ( "net/url" "os" "strconv" + "strings" "github.com/miniflux/miniflux/logger" ) @@ -54,6 +55,35 @@ func (c *Config) getInt(key string, fallback int) int { return v } +func (c *Config) parseBaseURL() { + baseURL := os.Getenv("BASE_URL") + if baseURL == "" { + return + } + + if baseURL[len(baseURL)-1:] == "/" { + baseURL = baseURL[:len(baseURL)-1] + } + + u, err := url.Parse(baseURL) + if err != nil { + logger.Error("Invalid BASE_URL: %v", err) + return + } + + scheme := strings.ToLower(u.Scheme) + if scheme != "https" && scheme != "http" { + logger.Error("Invalid BASE_URL: scheme must be http or https") + return + } + + c.baseURL = baseURL + c.basePath = u.Path + + u.Path = "" + c.rootURL = u.String() +} + // HasDebugMode returns true if debug mode is enabled. func (c *Config) HasDebugMode() bool { return c.get("DEBUG", "") != "" @@ -61,31 +91,16 @@ func (c *Config) HasDebugMode() bool { // BaseURL returns the application base URL with path. func (c *Config) BaseURL() string { - if c.baseURL == "" { - c.baseURL = c.get("BASE_URL", defaultBaseURL) - if c.baseURL[len(c.baseURL)-1:] == "/" { - c.baseURL = c.baseURL[:len(c.baseURL)-1] - } - } return c.baseURL } // RootURL returns the base URL without path. func (c *Config) RootURL() string { - if c.rootURL == "" { - u, _ := url.Parse(c.BaseURL()) - u.Path = "" - c.rootURL = u.String() - } return c.rootURL } // BasePath returns the application base path according to the base URL. func (c *Config) BasePath() string { - if c.basePath == "" { - u, _ := url.Parse(c.BaseURL()) - c.basePath = u.Path - } return c.basePath } @@ -204,5 +219,12 @@ func (c *Config) PocketConsumerKey(defaultValue string) string { // NewConfig returns a new Config. func NewConfig() *Config { - return &Config{IsHTTPS: os.Getenv("HTTPS") != ""} + cfg := &Config{ + baseURL: defaultBaseURL, + rootURL: defaultBaseURL, + IsHTTPS: os.Getenv("HTTPS") != "", + } + + cfg.parseBaseURL() + return cfg } diff --git a/config/config_test.go b/config/config_test.go index 2cfec81a..1f91add5 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -56,7 +56,7 @@ func TestCustomBaseURLWithTrailingSlash(t *testing.T) { } if cfg.RootURL() != "http://example.org" { - t.Fatalf(`Unexpected root URL, got "%s"`, cfg.BaseURL()) + t.Fatalf(`Unexpected root URL, got "%s"`, cfg.RootURL()) } if cfg.BasePath() != "/folder" { @@ -64,6 +64,42 @@ func TestCustomBaseURLWithTrailingSlash(t *testing.T) { } } +func TestBaseURLWithoutScheme(t *testing.T) { + os.Clearenv() + os.Setenv("BASE_URL", "example.org/folder/") + cfg := NewConfig() + + if cfg.BaseURL() != "http://localhost" { + t.Fatalf(`Unexpected base URL, got "%s"`, cfg.BaseURL()) + } + + if cfg.RootURL() != "http://localhost" { + t.Fatalf(`Unexpected root URL, got "%s"`, cfg.RootURL()) + } + + if cfg.BasePath() != "" { + t.Fatalf(`Unexpected base path, got "%s"`, cfg.BasePath()) + } +} + +func TestBaseURLWithInvalidScheme(t *testing.T) { + os.Clearenv() + os.Setenv("BASE_URL", "ftp://example.org/folder/") + cfg := NewConfig() + + if cfg.BaseURL() != "http://localhost" { + t.Fatalf(`Unexpected base URL, got "%s"`, cfg.BaseURL()) + } + + if cfg.RootURL() != "http://localhost" { + t.Fatalf(`Unexpected root URL, got "%s"`, cfg.RootURL()) + } + + if cfg.BasePath() != "" { + t.Fatalf(`Unexpected base path, got "%s"`, cfg.BasePath()) + } +} + func TestDefaultBaseURL(t *testing.T) { os.Clearenv() cfg := NewConfig()