diff --git a/cmd/gofs.go b/cmd/gofs.go index 898e175e..daddae27 100644 --- a/cmd/gofs.go +++ b/cmd/gofs.go @@ -59,7 +59,6 @@ func runWithConfig(c conf.Config, result result.Result) { }() cp := &c - conf.GlobalConfig = cp if err = parseConfigFile(cp); err != nil { result.InitDoneWithError(err) @@ -91,6 +90,8 @@ func runWithConfig(c conf.Config, result result.Result) { return } + log.ErrorIf(conf.SetGlobalConfig(cp), "set global config error => %s", cp.FileServerAddr) + // kill parent process if c.KillPPid { daemon.KillPPid() diff --git a/conf/global.go b/conf/global.go index a2a5fa14..3a0ed9a1 100644 --- a/conf/global.go +++ b/conf/global.go @@ -1,6 +1,52 @@ package conf +import ( + "errors" + "sync" +) + +var ( + globalConfigSet = &configSet{ + m: make(map[string]*Config), + } +) + var ( - // GlobalConfig the global config of the program, initial by flags or config file - GlobalConfig *Config + errConfigIsNil = errors.New("the config is nil") + errConfigExist = errors.New("the config exists") ) + +type configSet struct { + m map[string]*Config + mu sync.RWMutex +} + +func (cs *configSet) setGlobalConfig(c *Config) error { + if c == nil { + return errConfigIsNil + } + addr := c.FileServerAddr + cs.mu.Lock() + defer cs.mu.Unlock() + if _, ok := cs.m[addr]; ok { + return errConfigExist + } + cs.m[addr] = c + return nil +} + +func (cs *configSet) getGlobalConfig(addr string) *Config { + cs.mu.RLock() + defer cs.mu.RUnlock() + return cs.m[addr] +} + +// SetGlobalConfig set the global config once per web server +func SetGlobalConfig(c *Config) error { + return globalConfigSet.setGlobalConfig(c) +} + +// GetGlobalConfig get the global config by web server address +func GetGlobalConfig(addr string) *Config { + return globalConfigSet.getGlobalConfig(addr) +} diff --git a/conf/global_test.go b/conf/global_test.go new file mode 100644 index 00000000..76bc0118 --- /dev/null +++ b/conf/global_test.go @@ -0,0 +1,52 @@ +package conf + +import ( + "testing" +) + +func TestSetGlobalConfig(t *testing.T) { + testCases := []struct { + name string + config *Config + err error + }{ + {"normal address", &Config{FileServerAddr: ":8080"}, nil}, + {"empty address", &Config{FileServerAddr: ""}, nil}, + {"nil config", nil, errConfigIsNil}, + {"config exists", &Config{FileServerAddr: ":8080"}, errConfigExist}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := SetGlobalConfig(tc.config) + if err != tc.err { + t.Errorf("expect to get error %v, but get %v", tc.err, err) + } + }) + } +} + +func TestGetGlobalConfig(t *testing.T) { + if err := SetGlobalConfig(&Config{FileServerAddr: ":8088"}); err != nil { + t.Errorf("call SetGlobalConfig error") + return + } + testCases := []struct { + name string + address string + exist bool + }{ + {"normal address", ":8088", true}, + {"not exist config", ":8000", false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + config := GetGlobalConfig(tc.address) + exist := config != nil + if exist != tc.exist { + t.Errorf("expect to get config %v, but get %v", tc.exist, exist) + } + }) + } +} diff --git a/server/handler/manage_handler.go b/server/handler/manage_handler.go index a3ce4ac4..5e451acf 100644 --- a/server/handler/manage_handler.go +++ b/server/handler/manage_handler.go @@ -13,35 +13,44 @@ import ( type manageHandler struct { logger log.Logger + addr string } // NewManageHandlerFunc returns a gin.HandlerFunc that shows the application config -func NewManageHandlerFunc(logger log.Logger) gin.HandlerFunc { +func NewManageHandlerFunc(logger log.Logger, addr string) gin.HandlerFunc { return (&manageHandler{ logger: logger, + addr: addr, }).Handle } func (h *manageHandler) Handle(c *gin.Context) { format := strings.ToLower(c.Query(server.ParamFormat)) + var result server.ApiResult // copy the config and mask the user info for security - config := *conf.GlobalConfig - mask := "******" - if len(config.Users) > 0 { - config.Users = mask - } - if len(config.SessionConnection) > 0 { - config.SessionConnection = mask - } - if len(config.EncryptSecret) > 0 { - config.EncryptSecret = mask - } - if len(config.DecryptSecret) > 0 { - config.DecryptSecret = mask + cp := conf.GetGlobalConfig(h.addr) + if cp == nil { + result = server.NewErrorApiResult(contract.NotFound, contract.NotFoundDesc) + } else { + config := *cp + mask := "******" + if len(config.Users) > 0 { + config.Users = mask + } + if len(config.SessionConnection) > 0 { + config.SessionConnection = mask + } + if len(config.EncryptSecret) > 0 { + config.EncryptSecret = mask + } + if len(config.DecryptSecret) > 0 { + config.DecryptSecret = mask + } + result = server.NewApiResult(contract.Success, contract.SuccessDesc, config) } if format == conf.YamlFormat.Name() { - c.YAML(http.StatusOK, server.NewApiResult(contract.Success, contract.SuccessDesc, config)) + c.YAML(http.StatusOK, result) } else { - c.PureJSON(http.StatusOK, server.NewApiResult(contract.Success, contract.SuccessDesc, config)) + c.PureJSON(http.StatusOK, result) } } diff --git a/server/httpfs/file_server.go b/server/httpfs/file_server.go index bd4dcad9..eff31d0d 100644 --- a/server/httpfs/file_server.go +++ b/server/httpfs/file_server.go @@ -208,7 +208,7 @@ func initManageRoute(opt server.Option, logger log.Logger, manageGroup *gin.Rout manageGroup.Use(middleware.NewPrivateAccessHandlerFunc(logger)) } pprof.RouteRegister(manageGroup, server.PProfRoutePrefix) - manageGroup.GET(server.ManageConfigRoute, handler.NewManageHandlerFunc(logger)) + manageGroup.GET(server.ManageConfigRoute, handler.NewManageHandlerFunc(logger, opt.Addr)) if opt.EnableReport { manageGroup.GET(server.ManageReportRoute, handler.NewReportHandlerFunc(logger)) report.GlobalReporter.Enable(true)