diff --git a/lib/config/config.go b/lib/config/config.go index 08f14f9..974e9eb 100644 --- a/lib/config/config.go +++ b/lib/config/config.go @@ -4,9 +4,9 @@ import ( "os" "path/filepath" + "github.com/go-i2p/go-i2p/lib/util" "github.com/go-i2p/logger" "github.com/spf13/viper" - "gopkg.in/yaml.v3" ) var ( @@ -17,48 +17,13 @@ var ( const GOI2P_BASE_DIR = ".go-i2p" func InitConfig() { - defaultConfigDir := filepath.Join(os.Getenv("HOME"), GOI2P_BASE_DIR) - defaultConfigFile := filepath.Join(defaultConfigDir, "config.yaml") if CfgFile != "" { // Use config file from the flag viper.SetConfigFile(CfgFile) } else { - // Create default config if it doesn't exist - if _, err := os.Stat(defaultConfigFile); os.IsNotExist(err) { - // Ensure directory exists - if err := os.MkdirAll(defaultConfigDir, 0o755); err != nil { - log.Fatalf("Could not create config directory: %s", err) - } - - // Create default configuration - defaultConfig := struct { - BaseDir string `yaml:"base_dir"` - WorkingDir string `yaml:"working_dir"` - NetDB NetDbConfig `yaml:"netdb"` - Bootstrap BootstrapConfig `yaml:"bootstrap"` - }{ - BaseDir: DefaultRouterConfig().BaseDir, - WorkingDir: DefaultRouterConfig().WorkingDir, - NetDB: *DefaultRouterConfig().NetDb, - Bootstrap: *DefaultRouterConfig().Bootstrap, - } - - yamlData, err := yaml.Marshal(defaultConfig) - if err != nil { - log.Fatalf("Could not marshal default config: %s", err) - } - - // Write default config file - if err := os.WriteFile(defaultConfigFile, yamlData, 0o644); err != nil { - log.Fatalf("Could not write default config file: %s", err) - } - - log.Debugf("Created default configuration at: %s", defaultConfigFile) - } - - // Set up viper to use the config file - viper.AddConfigPath(defaultConfigDir) + // Set up viper to use the default config path $HOME/.go-ip/ + viper.AddConfigPath(BuildI2PDirPath()) viper.SetConfigName("config") viper.SetConfigType("yaml") } @@ -66,11 +31,8 @@ func InitConfig() { // Load defaults setDefaults() - if err := viper.ReadInConfig(); err != nil { - log.Warnf("Error reading config file: %s", err) - } else { - log.Debugf("Using config file: %s", viper.ConfigFileUsed()) - } + // handle config file creating it if needed + handleConfigFile() // Update RouterConfigProperties UpdateRouterConfig() @@ -111,3 +73,41 @@ func UpdateRouterConfig() { ReseedServers: reseedServers, } } + +func createDefaultConfig(defaultConfigDir string) { + + defaultConfigFile := filepath.Join(defaultConfigDir, "config.yaml") + // Ensure directory exists + if err := os.MkdirAll(defaultConfigDir, 0o755); err != nil { + log.Fatalf("Could not create config directory: %s", err) + } + + // Write current config file + if err := viper.WriteConfig(); err != nil { + log.Fatalf("Could not write default config file: %s", err) + } + + log.Debugf("Created default configuration at: %s", defaultConfigFile) + +} + +func handleConfigFile() { + if err := viper.ReadInConfig(); err != nil { + if _, ok := err.(viper.ConfigFileNotFoundError); ok { + if CfgFile != "" { + log.Fatalf("Config file %s is not found: %s", CfgFile, err) + } else { + createDefaultConfig(BuildI2PDirPath()) + } + } else { + log.Fatalf("Error reading config file: %s", err) + } + } else { + log.Debugf("Using config file: %s", viper.ConfigFileUsed()) + } + +} + +func BuildI2PDirPath() string { + return filepath.Join(util.UserHome(), GOI2P_BASE_DIR) +} diff --git a/lib/config/router.go b/lib/config/router.go index 4510e5a..976bb04 100644 --- a/lib/config/router.go +++ b/lib/config/router.go @@ -1,7 +1,6 @@ package config import ( - "os" "path/filepath" ) @@ -17,20 +16,12 @@ type RouterConfig struct { Bootstrap *BootstrapConfig } -func home() string { - h, err := os.UserHomeDir() - if err != nil { - panic(err) - } - return h -} - func defaultBase() string { - return filepath.Join(home(), GOI2P_BASE_DIR, "base") + return filepath.Join(BuildI2PDirPath(), "base") } func defaultConfig() string { - return filepath.Join(home(), GOI2P_BASE_DIR, "config") + return filepath.Join(BuildI2PDirPath(), "config") } // defaults for router diff --git a/lib/util/home.go b/lib/util/home.go new file mode 100644 index 0000000..fdee95a --- /dev/null +++ b/lib/util/home.go @@ -0,0 +1,15 @@ +package util + +import ( + "log" + "os" +) + +func UserHome() string { + homeDir, err := os.UserHomeDir() + if err != nil { + log.Fatalf("Unable to get current user's home directory. $HOME environment variable issue? %s", err) + } + + return homeDir +}