// Package config loads the application configuration through viper
// (config file, environment variables, command-line flags and defaults)
// and validates the resulting values.
package config

import (
	"errors"
	"fmt"
	"strings"

	validation "github.com/go-ozzo/ozzo-validation/v4"
	"github.com/go-ozzo/ozzo-validation/v4/is"
	"github.com/spf13/viper"
)

// Accepted values for application.environment. PRODENV is the default.
const (
	PRODENV = "production"
	TESTENV = "test"
	DEVENV  = "development"
)

// Accepted values for database.dialect.
const (
	MYSQL    = "mysql"
	POSTGRES = "postgres"
	MSSQL    = "mssql"
	SQLITE   = "sqlite"
)

// Sample is an example YAML configuration. Values written as `a|b|c` list
// the accepted alternatives; exactly one of them must be chosen.
var Sample = `application:
  environment: ` + PRODENV + `|` + TESTENV + `|` + DEVENV + `
  listenAddress: 127.0.0.1
  port: 5100
  requestIDHeaderName: X-Request-Id
database:
  host: 127.0.0.1
  port: 3306
  dialect: ` + MYSQL + `|` + POSTGRES + `|` + MSSQL + `|` + SQLITE + `
  database: skel
  user: skel
  password: secret
  maxOpenConn: 90
  maxIdleConn: 20
  maxLifeTime: 1800
`

// setDefaults registers fallback values that viper returns when a key is
// not supplied by any higher-priority source (flag, environment, file).
// database.port has no default; an empty port passes validatePort.
func setDefaults() {
	viper.SetDefault("application.environment", PRODENV)
	viper.SetDefault("application.listenAddress", "127.0.0.1")
	viper.SetDefault("application.port", "8080")
	viper.SetDefault("database.host", "127.0.0.1")
	viper.SetDefault("database.maxOpenConn", "90")
	viper.SetDefault("database.maxIdleConn", "10")
	viper.SetDefault("database.maxLifeTime", "1800")
}

// setCobraFlags copies flag values into viper as explicit overrides.
// Keys are viper keys; values must be *string or *int (presumably bound
// to cobra flags by the caller — confirm at call site). Nil pointers,
// empty strings, non-positive ints and any other types are skipped, so
// unset flags do not shadow file, environment or default values.
func setCobraFlags(flags map[string]interface{}) {
	for key, value := range flags {
		switch v := value.(type) {
		case *string:
			if v != nil && *v != "" {
				viper.Set(key, *v)
			}
		case *int:
			if v != nil && *v > 0 {
				viper.Set(key, *v)
			}
		}
	}
}

// Initialize loads and validates the configuration for appName.
//
// If cfgFile is empty, a config file named appName is searched for in
// /etc/<appName>/, $HOME/.<appName> and the current directory, and a
// viper.ConfigFileNotFoundError is tolerated. Otherwise cfgFile is read
// directly. Any other read/parse error is returned wrapped.
//
// Environment variables are consulted with appName as prefix and "." in
// keys replaced by "_". The non-empty values in flags are applied as
// overrides, and defaults are registered last. Returns the result of
// validating the merged configuration.
func Initialize(appName string, cfgFile string, flags map[string]interface{}) error {
	viper.SetConfigName(appName)
	if cfgFile == "" {
		viper.AddConfigPath("/etc/" + appName + "/")
		viper.AddConfigPath("$HOME/." + appName)
		viper.AddConfigPath(".")
	} else {
		viper.SetConfigFile(cfgFile)
	}

	if err := viper.ReadInConfig(); err != nil {
		// Running without any config file is allowed: env vars, flags and
		// defaults may provide everything. Other failures are fatal.
		var notFound viper.ConfigFileNotFoundError
		if !errors.As(err, &notFound) {
			return fmt.Errorf("configuration broken or missing(%w)", err)
		}
	}

	viper.SetEnvPrefix(appName)
	viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
	viper.AutomaticEnv()

	setCobraFlags(flags)
	setDefaults()

	return validate()
}

// validate runs every field check against the current viper state and
// returns a single error joining all failure messages with ", ", or nil
// when every check passes.
func validate() error {
	checks := []error{
		validateEnvironment(viper.GetString("application.environment")),
		validateListenAddress(viper.GetString("application.listenaddress")),
		validatePort("application", viper.GetString("application.port")),
		validateHost("database", viper.GetString("database.host")),
		validatePort("database", viper.GetString("database.port")),
		validateDatabaseDialect(viper.GetString("database.dialect")),
		validateUsername(viper.GetString("database.user")),
	}

	var messages []string
	for _, err := range checks {
		if err != nil {
			messages = append(messages, err.Error())
		}
	}
	if len(messages) > 0 {
		return errors.New(strings.Join(messages, ", "))
	}
	return nil
}

// validateEnvironment reports an error when environment is not one of
// PRODENV, TESTENV or DEVENV. As with all ozzo rules except Required, an
// empty string is treated as valid.
func validateEnvironment(environment string) error {
	if err := validation.Validate(
		environment,
		validation.In(PRODENV, TESTENV, DEVENV),
	); err != nil {
		return fmt.Errorf(
			"application.environment (must be one of %s, %s or %s)",
			PRODENV, TESTENV, DEVENV,
		)
	}
	return nil
}

// validateDatabaseDialect reports an error when dialect is not one of the
// supported dialect constants.
// NOTE(review): an empty dialect passes (validation.In skips empty values)
// and there is no default — confirm an unset dialect is acceptable.
func validateDatabaseDialect(dialect string) error {
	if err := validation.Validate(
		dialect,
		validation.In(MYSQL, POSTGRES, MSSQL, SQLITE),
	); err != nil {
		return fmt.Errorf(
			"database.dialect (must be one of %s, %s, %s or %s)",
			MYSQL, POSTGRES, MSSQL, SQLITE,
		)
	}
	return nil
}

// validateListenAddress reports an error when listenAddress is neither a
// valid IP address nor a DNS name (ozzo is.Host). Empty values pass.
func validateListenAddress(listenAddress string) error {
	if err := validation.Validate(
		listenAddress,
		is.Host,
	); err != nil {
		return errors.New("application.listenaddress (" + err.Error() + ")")
	}
	return nil
}

// validateUsername reports an error when username is empty.
func validateUsername(username string) error {
	if err := validation.Validate(
		username,
		validation.Required,
	); err != nil {
		return errors.New("database.user (" + err.Error() + ")")
	}
	return nil
}

// validatePort reports an error, prefixed with "<parent>.port", when port
// is not a valid port number (ozzo is.Port). Empty values pass.
func validatePort(parent string, port string) error {
	if err := validation.Validate(
		port,
		is.Port,
	); err != nil {
		return errors.New(parent + ".port (" + err.Error() + ")")
	}
	return nil
}

// validateHost reports an error, prefixed with "<parent>.host", when host
// is neither a valid IP address nor a DNS name (ozzo is.Host). Empty
// values pass.
func validateHost(parent string, host string) error {
	if err := validation.Validate(
		host,
		is.Host,
	); err != nil {
		return errors.New(parent + ".host (" + err.Error() + ")")
	}
	return nil
}