You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

173 lines
4.0 KiB

package config
import (
"errors"
"fmt"
"strings"
validation "github.com/go-ozzo/ozzo-validation"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/spf13/viper"
)
const (
	// Values accepted for application.environment (see validateEnvironment).
	PRODENV = "production"
	TESTENV = "test"
	DEVENV  = "development"

	// Database dialect names accepted for database.dialect
	// (see validateDatabaseDialect).
	MYSQL    = "mysql"
	POSTGRES = "postgres"
	MSSQL    = "mssql"
)

// Sample is an example YAML configuration file.
//
// Keys are indented under their "application:" / "database:" section so the
// document maps onto the dotted application.* and database.* keys that
// validate reads; without the indentation every key would land at the top
// level and both sections would be empty.
//
// Values written as "a|b|c" list the accepted alternatives. They must be
// replaced by exactly one of them before use, because validation only accepts
// a single listed value.
var Sample = `application:
  environment: ` + PRODENV + `|` + TESTENV + `|` + DEVENV + `
  listenAddress: 127.0.0.1
  port: 5100
  requestIDHeaderName: X-Request-Id
database:
  host: 127.0.0.1
  port: 3306
  dialect: ` + MYSQL + `|` + POSTGRES + `|` + MSSQL + `
  database: skel
  user: skel
  password: secret
  maxOpenConn: 90
  maxIdleConn: 20
  maxLifeTime: 1800
`
func setDefaults() {
viper.SetDefault("application.environment", PRODENV)
viper.SetDefault("application.listenAddress", "127.0.0.1")
viper.SetDefault("application.port", "8080")
viper.SetDefault("database.host", "127.0.0.1")
viper.SetDefault("database.maxOpenConn", "90")
viper.SetDefault("database.maxIdleConn", "10")
viper.SetDefault("database.maxLifeTime", "1800")
}
// setCobraFlags copies explicitly provided flag values into viper with
// viper.Set, which viper documents as overriding every other source.
//
// flags maps a viper key (for example "application.port") to a pointer,
// presumably one bound to a cobra flag. Only *string and *int values are
// recognised; values of any other type are ignored. A flag counts as
// "not provided", and is left to lower-priority sources, when its pointer is
// nil, a string is empty, or an int is not positive.
func setCobraFlags(flags map[string]interface{}) {
	for key, value := range flags {
		switch v := value.(type) {
		case *string:
			// A nil pointer (flag never bound) is skipped instead of
			// panicking on the dereference.
			if v != nil && *v != "" {
				viper.Set(key, *v)
			}
		case *int:
			if v != nil && *v > 0 {
				viper.Set(key, *v)
			}
		}
	}
}
// Initialize loads and validates the configuration for appName.
//
// When cfgFile is empty, viper looks for a file named appName in
// "/etc/<appName>/", "$HOME/.<appName>" and the current directory. Not
// finding one there is tolerated. Otherwise cfgFile is read directly.
// NOTE(review): viper presumably reports a missing explicit cfgFile as an OS
// error rather than ConfigFileNotFoundError, so that case is returned — confirm.
//
// After reading, the following are layered on:
//   - environment variables, prefixed with appName and with "." in keys
//     mapped to "_";
//   - the values in flags (see setCobraFlags);
//   - the built-in defaults (see setDefaults).
//
// Any read/parse error other than "not found" is returned wrapped (%w), so
// callers can inspect it with errors.Is/errors.As. Otherwise the result of
// validate is returned.
func Initialize(appName string, cfgFile string, flags map[string]interface{}) error {
	viper.SetConfigName(appName)
	if cfgFile != "" {
		// An explicit file is used instead of searching the config paths.
		viper.SetConfigFile(cfgFile)
	} else {
		viper.AddConfigPath("/etc/" + appName + "/")
		viper.AddConfigPath("$HOME/." + appName)
		viper.AddConfigPath(".")
	}

	if err := viper.ReadInConfig(); err != nil {
		// errors.As (unlike a bare type assertion) also matches a wrapped
		// not-found error.
		var notFound viper.ConfigFileNotFoundError
		if !errors.As(err, &notFound) {
			return fmt.Errorf("configuration broken or missing(%w)", err)
		}
	}

	viper.SetEnvPrefix(appName)
	viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
	viper.AutomaticEnv()

	setCobraFlags(flags)
	setDefaults()

	return validate()
}
// validate checks the loaded configuration values and reports every failing
// key at once rather than stopping at the first.
//
// It returns nil when all checks pass. Otherwise it returns a single error
// whose message is the individual failure messages joined with ", ".
func validate() error {
	// Typed as []error so the compiler guarantees every entry is an error;
	// nil entries are passing checks.
	checks := []error{
		validateEnvironment(viper.GetString("application.environment")),
		validateListenAddress(viper.GetString("application.listenaddress")),
		validatePort("application", viper.GetString("application.port")),
		validateHost("database", viper.GetString("database.host")),
		validatePort("database", viper.GetString("database.port")),
		validateDatabaseDialect(viper.GetString("database.dialect")),
		validateUsername(viper.GetString("database.user")),
	}

	messages := make([]string, 0, len(checks))
	for _, err := range checks {
		if err != nil {
			messages = append(messages, err.Error())
		}
	}
	if len(messages) == 0 {
		return nil
	}
	return errors.New(strings.Join(messages, ", "))
}
// validateEnvironment reports an error unless environment is one of PRODENV,
// TESTENV or DEVENV. The underlying validation error is replaced with a fixed
// message naming the key and the accepted values.
// Note: ozzo-validation's In rule treats an empty string as valid, so "" is
// not rejected here.
func validateEnvironment(environment string) error {
	allowed := validation.In(PRODENV, TESTENV, DEVENV)
	if validation.Validate(environment, allowed) == nil {
		return nil
	}
	return fmt.Errorf(
		"application.environment (must be one of %s, %s or %s)",
		PRODENV, TESTENV, DEVENV,
	)
}
// validateDatabaseDialect reports an error unless dialect is MYSQL, POSTGRES,
// MSSQL or "sqlite".
//
// The error message lists every accepted value. Previously it omitted
// "sqlite", which is accepted, so the message contradicted the rule.
// Note: ozzo-validation's In rule treats an empty string as valid, so "" is
// not rejected here.
func validateDatabaseDialect(dialect string) error {
	// sqlite has no package-level constant alongside MYSQL/POSTGRES/MSSQL.
	const sqlite = "sqlite"
	if err := validation.Validate(
		dialect,
		validation.In(MYSQL, POSTGRES, MSSQL, sqlite),
	); err != nil {
		return fmt.Errorf(
			"database.dialect (must be one of %s, %s, %s or %s)",
			MYSQL, POSTGRES, MSSQL, sqlite,
		)
	}
	return nil
}
// validateListenAddress checks listenAddress with ozzo-validation's is.Host
// rule (an IP address or a DNS name). On failure the rule's message is
// wrapped in parentheses after the key "application.listenaddress".
func validateListenAddress(listenAddress string) error {
	err := validation.Validate(listenAddress, is.Host)
	if err == nil {
		return nil
	}
	return errors.New("application.listenaddress (" + err.Error() + ")")
}
// validateUsername requires a non-empty database user name. On failure the
// rule's message is wrapped in parentheses after the key "database.user".
func validateUsername(username string) error {
	err := validation.Validate(username, validation.Required)
	if err == nil {
		return nil
	}
	return errors.New("database.user (" + err.Error() + ")")
}
// validatePort checks port with ozzo-validation's is.Port rule. parent is the
// config section ("application", "database", ...) and is used to build the
// "<parent>.port" prefix of the returned error.
// Note: the is rules treat an empty string as valid, so an unset port passes.
func validatePort(parent string, port string) error {
	err := validation.Validate(port, is.Port)
	if err == nil {
		return nil
	}
	return errors.New(parent + ".port (" + err.Error() + ")")
}
// validateHost checks host with ozzo-validation's is.Host rule (an IP address
// or a DNS name). parent is the config section and is used to build the
// "<parent>.host" prefix of the returned error.
func validateHost(parent string, host string) error {
	err := validation.Validate(host, is.Host)
	if err == nil {
		return nil
	}
	return errors.New(parent + ".host (" + err.Error() + ")")
}