Skip to content

Commit

Permalink
Use a better method to reload command presets
Browse files Browse the repository at this point in the history
  • Loading branch information
NI committed Mar 12, 2020
1 parent b2210ce commit 1009cd4
Show file tree
Hide file tree
Showing 12 changed files with 202 additions and 370 deletions.
14 changes: 11 additions & 3 deletions application/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,7 @@ func (a Application) run(
) (bool, error) {
var err error

loaderName, c, cErr := cLoader(
a.logger.Context("Configuration"),
commands.Reconfigure)
loaderName, c, cErr := cLoader(a.logger.Context("Configuration"))

if cErr != nil {
a.logger.Error("\"%s\" loader cannot load configuration: %s",
Expand All @@ -83,6 +81,16 @@ func (a Application) run(
return false, cErr
}

// Allowing command to alter presets
c.Presets, err = commands.Reconfigure(c.Presets)

if err != nil {
a.logger.Error("Unable to reconfigure presets: %s", err)

return false, err
}

// Verify all configuration
err = c.Verify()

if err != nil {
Expand Down
40 changes: 29 additions & 11 deletions application/command/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,15 @@ type Command func(

// Builder builds a command
type Builder struct {
name string
command Command
configurator configuration.Reconfigurator
configurator configuration.PresetReloader
}

// Register builds a Builder for registration
func Register(c Command, p configuration.Reconfigurator) Builder {
func Register(name string, c Command, p configuration.PresetReloader) Builder {
return Builder{
name: name,
command: c,
configurator: p,
}
Expand All @@ -62,7 +64,11 @@ type Commands [MaxCommandID + 1]Builder

// Register registers a new command
func (c *Commands) Register(
id byte, cb Command, ps configuration.Reconfigurator) {
id byte,
name string,
cb Command,
ps configuration.PresetReloader,
) {
if id > MaxCommandID {
panic("Command ID must be not greater than MaxCommandID")
}
Expand All @@ -71,7 +77,7 @@ func (c *Commands) Register(
panic(fmt.Sprintf("Command %d already been registered", id))
}

(*c)[id] = Register(cb, ps)
(*c)[id] = Register(name, cb, ps)
}

// Run creates command executer
Expand All @@ -96,15 +102,27 @@ func (c Commands) Run(

// Reconfigure lets commands reset configuration
func (c Commands) Reconfigure(
p configuration.Configuration,
) configuration.Configuration {
p []configuration.Preset,
) ([]configuration.Preset, error) {
newP := make([]configuration.Preset, 0, len(p))

for i := range c {
if c[i].configurator == nil {
continue
}
for pp := range p {
if c[i].name != p[pp].Type {
continue
}

newPP, pErr := c[i].configurator(p[pp])

p = c[i].configurator(p)
if pErr == nil {
newP = append(newP, newPP)

continue
}

return nil, pErr
}
}

return p
return newP, nil
}
2 changes: 1 addition & 1 deletion application/command/handler_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func (d *dummyStreamCommand) Release() error {

func TestHandlerHandleStream(t *testing.T) {
cmds := Commands{}
cmds.Register(0, newDummyStreamCommand, nil)
cmds.Register(0, "name", newDummyStreamCommand, nil)

readerDataInput := make(chan []byte)

Expand Down
4 changes: 2 additions & 2 deletions application/commands/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
// New creates a new commands group
func New() command.Commands {
return command.Commands{
command.Register(newTelnet, parseTelnetConfig),
command.Register(newSSH, parseSSHConfig),
command.Register("Telnet", newTelnet, parseTelnetConfig),
command.Register("SSH", newSSH, parseSSHConfig),
}
}
26 changes: 9 additions & 17 deletions application/commands/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,28 +203,20 @@ func newSSH(
}
}

func parseSSHConfig(p configuration.Configuration) configuration.Configuration {
for i := range p.Presets {
if p.Presets[i].Type != "SSH" {
continue
}

oldHost := p.Presets[i].Host
func parseSSHConfig(p configuration.Preset) (configuration.Preset, error) {
oldHost := p.Host

_, _, sErr := net.SplitHostPort(p.Presets[i].Host)
_, _, sErr := net.SplitHostPort(p.Host)

if sErr != nil {
p.Presets[i].Host = net.JoinHostPort(
p.Presets[i].Host,
sshDefaultPortString)
}
if sErr != nil {
p.Host = net.JoinHostPort(p.Host, sshDefaultPortString)
}

if len(p.Presets[i].Host) <= 0 {
p.Presets[i].Host = oldHost
}
if len(p.Host) <= 0 {
p.Host = oldHost
}

return p
return p, nil
}

func (d *sshClient) Bootup(
Expand Down
28 changes: 9 additions & 19 deletions application/commands/telnet.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,30 +76,20 @@ func newTelnet(
}
}

func parseTelnetConfig(
p configuration.Configuration,
) configuration.Configuration {
for i := range p.Presets {
if p.Presets[i].Type != "Telnet" {
continue
}

oldHost := p.Presets[i].Host
func parseTelnetConfig(p configuration.Preset) (configuration.Preset, error) {
oldHost := p.Host

_, _, sErr := net.SplitHostPort(p.Presets[i].Host)
_, _, sErr := net.SplitHostPort(p.Host)

if sErr != nil {
p.Presets[i].Host = net.JoinHostPort(
p.Presets[i].Host,
telnetDefaultPortString)
}
if sErr != nil {
p.Host = net.JoinHostPort(p.Host, telnetDefaultPortString)
}

if len(p.Presets[i].Host) <= 0 {
p.Presets[i].Host = oldHost
}
if len(p.Host) <= 0 {
p.Host = oldHost
}

return p
return p, nil
}

func (d *telnetClient) Bootup(
Expand Down
9 changes: 3 additions & 6 deletions application/configuration/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,8 @@ import (
"github.com/niruix/sshwifty/application/log"
)

// Reconfigurator reloads configuration
type Reconfigurator func(p Configuration) Configuration
// PresetReloader reloads preset
type PresetReloader func(p Preset) (Preset, error)

// Loader Configuration loader
type Loader func(
log log.Logger,
r Reconfigurator,
) (name string, cfg Configuration, err error)
type Loader func(log log.Logger) (name string, cfg Configuration, err error)
5 changes: 1 addition & 4 deletions application/configuration/loader_direct.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@ const (
// Direct creates a loader that return raw configuration data directly.
// Good for integration.
func Direct(cfg Configuration) Loader {
return func(
log log.Logger,
r Reconfigurator,
) (string, Configuration, error) {
return func(log log.Logger) (string, Configuration, error) {
return directTypeName, cfg, nil
}
}
9 changes: 3 additions & 6 deletions application/configuration/loader_enviro.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,7 @@ func parseEviro(name string) string {

// Enviro creates an environment variable based configuration loader
func Enviro() Loader {
return func(
log log.Logger,
r Reconfigurator,
) (string, Configuration, error) {
return func(log log.Logger) (string, Configuration, error) {
log.Info("Loading configuration from environment variables ...")

dialTimeout, _ := strconv.ParseUint(
Expand Down Expand Up @@ -123,7 +120,7 @@ func Enviro() Loader {
}
}

return enviroTypeName, r(Configuration{
return enviroTypeName, Configuration{
HostName: cfg.HostName,
SharedKey: cfg.SharedKey,
DialTimeout: time.Duration(cfg.DialTimeout) * time.Second,
Expand All @@ -133,6 +130,6 @@ func Enviro() Loader {
Servers: []Server{cfgSer.build()},
Presets: presets,
OnlyAllowPresetRemotes: cfg.OnlyAllowPresetRemotes,
}), nil
}, nil
}
}
18 changes: 6 additions & 12 deletions application/configuration/loader_file.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,7 @@ func (f fileCfgCommon) build() (fileCfgCommon, error) {
}, nil
}

func loadFile(
filePath string,
r Reconfigurator,
) (string, Configuration, error) {
func loadFile(filePath string) (string, Configuration, error) {
f, fErr := os.Open(filePath)

if fErr != nil {
Expand Down Expand Up @@ -174,7 +171,7 @@ func loadFile(
presets[i] = finalCfg.Presets[i].build()
}

return fileTypeName, r(Configuration{
return fileTypeName, Configuration{
HostName: finalCfg.HostName,
SharedKey: finalCfg.SharedKey,
DialTimeout: time.Duration(finalCfg.DialTimeout) *
Expand All @@ -185,19 +182,16 @@ func loadFile(
Servers: servers,
Presets: presets,
OnlyAllowPresetRemotes: cfg.OnlyAllowPresetRemotes,
}), nil
}, nil
}

// File creates a configuration file loader
func File(customPath string) Loader {
return func(
log log.Logger,
r Reconfigurator,
) (string, Configuration, error) {
return func(log log.Logger) (string, Configuration, error) {
if len(customPath) > 0 {
log.Info("Loading configuration from: %s", customPath)

return loadFile(customPath, r)
return loadFile(customPath)
}

log.Info("Loading configuration from one of the default " +
Expand Down Expand Up @@ -239,7 +233,7 @@ func File(customPath string) Loader {
log.Info("Configuration file \"%s\" has been selected",
fallbackFileSearchList[f])

return loadFile(fallbackFileSearchList[f], r)
return loadFile(fallbackFileSearchList[f])
}

return fileTypeName, Configuration{}, fmt.Errorf(
Expand Down
7 changes: 2 additions & 5 deletions application/configuration/loader_redundant.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,11 @@ const (
// Redundant creates a group of loaders. They will be executed one by one until
// one of it successfully returned a configuration
func Redundant(loaders ...Loader) Loader {
return func(
log log.Logger,
r Reconfigurator,
) (string, Configuration, error) {
return func(log log.Logger) (string, Configuration, error) {
ll := log.Context("Redundant")

for i := range loaders {
lLoaderName, lCfg, lErr := loaders[i](ll, r)
lLoaderName, lCfg, lErr := loaders[i](ll)

if lErr != nil {
ll.Warning("Unable to load configuration from \"%s\": %s",
Expand Down
Loading

0 comments on commit 1009cd4

Please sign in to comment.