diff --git a/ac/config.go b/ac/config.go index 7402f11a..7f6ecefe 100644 --- a/ac/config.go +++ b/ac/config.go @@ -1,156 +1,225 @@ -package ac - -import ( - "fmt" - "io" - "os" - "path/filepath" - - "github.com/OpenNHP/opennhp/core" - "github.com/OpenNHP/opennhp/log" - "github.com/OpenNHP/opennhp/utils" - - toml "github.com/pelletier/go-toml/v2" -) - -var ( - baseConfigWatch io.Closer - serverConfigWatch io.Closer - - errLoadConfig = fmt.Errorf("config load error") -) - -type Config struct { - PrivateKeyBase64 string `json:"privateKey"` - ACId string `json:"acId"` - DefaultIp string `json:"defaultIp"` - AuthServiceId string `json:"aspId"` - ResourceIds []string `json:"resIds"` - Servers []*core.UdpPeer `json:"servers"` - IpPassMode int `json:"ipPassMode"` // 0: pass the knock source IP, 1: use pre-access mode and release the access source IP - LogLevel int `json:"logLevel"` -} - -type Peers struct { - Servers []*core.UdpPeer -} - -func (d *UdpAC) loadBaseConfig() error { - // config.toml - fileName := filepath.Join(ExeDirPath, "etc", "config.toml") - if err := d.updateBaseConfig(fileName); err != nil { - // report base config error - return err - } - - baseConfigWatch = utils.WatchFile(fileName, func() { - log.Info("base config: %s has been updated", fileName) - d.updateBaseConfig(fileName) - }) - return nil -} - -func (d *UdpAC) loadPeers() error { - // server.toml - fileName := filepath.Join(ExeDirPath, "etc", "server.toml") - if err := d.updateServerPeers(fileName); err != nil { - // ignore error - _ = err - } - - serverConfigWatch = utils.WatchFile(fileName, func() { - log.Info("server peer config: %s has been updated", fileName) - d.updateServerPeers(fileName) - }) - - return nil -} - -func (d *UdpAC) updateBaseConfig(file string) (err error) { - utils.CatchPanicThenRun(func() { - err = errLoadConfig - }) - - content, err := os.ReadFile(file) - if err != nil { - log.Error("failed to read base config: %v", err) - } - - var conf Config - if err := toml.Unmarshal(content, &conf); err != nil { - log.Error("failed to unmarshal base config: %v", err) - } - - if d.config == nil { - d.config = &conf - d.log.SetLogLevel(conf.LogLevel) - return err - } - - // update - if d.config.LogLevel != conf.LogLevel { - log.Info("set base log level to %d", conf.LogLevel) - d.log.SetLogLevel(conf.LogLevel) - d.config.LogLevel = conf.LogLevel - } - - if d.config.DefaultIp != conf.DefaultIp { - log.Info("set default ip mode to %s", conf.DefaultIp) - d.config.DefaultIp = conf.DefaultIp - } - - if d.config.IpPassMode != conf.IpPassMode { - log.Info("set ip pass mode to %d", conf.IpPassMode) - d.config.IpPassMode = conf.IpPassMode - } - - return err -} - -func (d *UdpAC) updateServerPeers(file string) (err error) { - utils.CatchPanicThenRun(func() { - err = errLoadConfig - }) - - content, err := os.ReadFile(file) - if err != nil { - log.Error("failed to read server peer config: %v", err) - } - - // update - var peers Peers - serverPeerMap := make(map[string]*core.UdpPeer) - if err := toml.Unmarshal(content, &peers); err != nil { - log.Error("failed to unmarshal server peer config: %v", err) - } - for _, p := range peers.Servers { - p.Type = core.NHP_SERVER - d.device.AddPeer(p) - serverPeerMap[p.PublicKeyBase64()] = p - } - - // remove old peers from device - d.serverPeerMutex.Lock() - defer d.serverPeerMutex.Unlock() - for pubKey := range d.serverPeerMap { - if _, found := serverPeerMap[pubKey]; !found { - d.device.RemovePeer(pubKey) - } - } - d.serverPeerMap = serverPeerMap - - return err -} - -func (d *UdpAC) IpPassMode() int { - return d.config.IpPassMode -} - -func (d *UdpAC) StopConfigWatch() { - if baseConfigWatch != nil { - baseConfigWatch.Close() - } - if serverConfigWatch != nil { - serverConfigWatch.Close() - } -} +package ac + +import ( + "fmt" + "io" + "os" + "path/filepath" + + "github.com/OpenNHP/opennhp/core" + "github.com/OpenNHP/opennhp/log" + "github.com/OpenNHP/opennhp/utils" + + toml "github.com/pelletier/go-toml/v2" +) + +var ( + baseConfigWatch io.Closer + httpConfigWatch io.Closer + serverPeerWatch io.Closer + + errLoadConfig = fmt.Errorf("config load error") +) + +type Config struct { + PrivateKeyBase64 string `json:"privateKey"` + ACId string `json:"acId"` + DefaultIp string `json:"defaultIp"` + AuthServiceId string `json:"aspId"` + ResourceIds []string `json:"resIds"` + Servers []*core.UdpPeer `json:"servers"` + IpPassMode int `json:"ipPassMode"` // 0: pass the knock source IP, 1: use pre-access mode and release the access source IP + LogLevel int `json:"logLevel"` +} + +type HttpConfig struct { + EnableHttp bool + EnableTLS bool + HttpListenPort int + TLSCertFile string + TLSKeyFile string +} + +type Peers struct { + Servers []*core.UdpPeer +} + +func (a *UdpAC) loadBaseConfig() error { + // config.toml + fileName := filepath.Join(ExeDirPath, "etc", "config.toml") + if err := a.updateBaseConfig(fileName); err != nil { + // report base config error + return err + } + + baseConfigWatch = utils.WatchFile(fileName, func() { + log.Info("base config: %s has been updated", fileName) + a.updateBaseConfig(fileName) + }) + return nil +} + +func (a *UdpAC) loadHttpConfig() error { + // http.toml + fileName := filepath.Join(ExeDirPath, "etc", "http.toml") + if err := a.updateHttpConfig(fileName); err != nil { + // ignore error + _ = err + } + + httpConfigWatch = utils.WatchFile(fileName, func() { + log.Info("http config: %s has been updated", fileName) + a.updateHttpConfig(fileName) + }) + return nil +} + +func (a *UdpAC) loadPeers() error { + // server.toml + fileName := filepath.Join(ExeDirPath, "etc", "server.toml") + if err := a.updateServerPeers(fileName); err != nil { + // ignore error + _ = err + } + + serverPeerWatch = utils.WatchFile(fileName, func() { + log.Info("server peer config: %s has been updated", fileName) + a.updateServerPeers(fileName) + }) + + return nil +} + +func (a *UdpAC) updateBaseConfig(file string) (err error) { + utils.CatchPanicThenRun(func() { + err = errLoadConfig + }) + + content, err := os.ReadFile(file) + if err != nil { + log.Error("failed to read base config: %v", err) + } + + var conf Config + if err := toml.Unmarshal(content, &conf); err != nil { + log.Error("failed to unmarshal base config: %v", err) + } + + if a.config == nil { + a.config = &conf + a.log.SetLogLevel(conf.LogLevel) + return err + } + + // update + if a.config.LogLevel != conf.LogLevel { + log.Info("set base log level to %d", conf.LogLevel) + a.log.SetLogLevel(conf.LogLevel) + a.config.LogLevel = conf.LogLevel + } + + if a.config.DefaultIp != conf.DefaultIp { + log.Info("set default ip mode to %s", conf.DefaultIp) + a.config.DefaultIp = conf.DefaultIp + } + + if a.config.IpPassMode != conf.IpPassMode { + log.Info("set ip pass mode to %d", conf.IpPassMode) + a.config.IpPassMode = conf.IpPassMode + } + + return err +} + +func (a *UdpAC) updateHttpConfig(file string) (err error) { + utils.CatchPanicThenRun(func() { + err = errLoadConfig + }) + + content, err := os.ReadFile(file) + if err != nil { + log.Error("failed to read http config: %v", err) + } + + var httpConf HttpConfig + if err := toml.Unmarshal(content, &httpConf); err != nil { + log.Error("failed to unmarshal http config: %v", err) + } + + // update + if httpConf.EnableHttp { + // start http server + if a.httpServer == nil || !a.httpServer.IsRunning() { + if a.httpServer != nil { + // stop old http server + go a.httpServer.Stop() + } + hs := &HttpAC{} + a.httpServer = hs + err = hs.Start(a, &httpConf) + if err != nil { + return err + } + } + } else { + // stop http server + if a.httpServer != nil && a.httpServer.IsRunning() { + go a.httpServer.Stop() + a.httpServer = nil + } + } + + a.httpConfig = &httpConf + return err +} + +func (a *UdpAC) updateServerPeers(file string) (err error) { + utils.CatchPanicThenRun(func() { + err = errLoadConfig + }) + + content, err := os.ReadFile(file) + if err != nil { + log.Error("failed to read server peer config: %v", err) + } + + // update + var peers Peers + serverPeerMap := make(map[string]*core.UdpPeer) + if err := toml.Unmarshal(content, &peers); err != nil { + log.Error("failed to unmarshal server peer config: %v", err) + } + for _, p := range peers.Servers { + p.Type = core.NHP_SERVER + a.device.AddPeer(p) + serverPeerMap[p.PublicKeyBase64()] = p + } + + // remove old peers from device + a.serverPeerMutex.Lock() + defer a.serverPeerMutex.Unlock() + for pubKey := range a.serverPeerMap { + if _, found := serverPeerMap[pubKey]; !found { + a.device.RemovePeer(pubKey) + } + } + a.serverPeerMap = serverPeerMap + + return err +} + +func (a *UdpAC) IpPassMode() int { + return a.config.IpPassMode +} + +func (a *UdpAC) StopConfigWatch() { + if baseConfigWatch != nil { + baseConfigWatch.Close() + } + if httpConfigWatch != nil { + httpConfigWatch.Close() + } + if serverPeerWatch != nil { + serverPeerWatch.Close() + } +} diff --git a/ac/constants.go b/ac/constants.go index 500f3d6d..db52786e 100644 --- a/ac/constants.go +++ b/ac/constants.go @@ -10,7 +10,8 @@ const ( ServerKeepaliveInterval = 20 // seconds ServerDiscoveryRetryBeforeFail = 3 - TempPortOpenTime = 30 // + TokenStoreRefreshInterval = 10 + TempPortOpenTime = 30 // IPSET_DEFAULT_NAME = "defaultset" IPSET_DEFAULT_DOWN_NAME = "defaultset_down" diff --git a/ac/httpac.go b/ac/httpac.go new file mode 100644 index 00000000..70b7409a --- /dev/null +++ b/ac/httpac.go @@ -0,0 +1,208 @@ +package ac + +import ( + "context" + "encoding/base64" + "net" + "net/http" + "os" + "path/filepath" + "sync" + "sync/atomic" + "time" + + "github.com/OpenNHP/opennhp/common" + "github.com/OpenNHP/opennhp/log" + "github.com/gin-gonic/gin" +) + +type HttpAC struct { + id string + ua *UdpAC + httpServer *http.Server + ginEngine *gin.Engine + listenAddr *net.TCPAddr + + wg sync.WaitGroup + running atomic.Bool + + // signals + signals struct { + stop chan struct{} + } +} + +// Note HttpServer must be started after starting UdpAC, when log and config have been setup +func (hs *HttpAC) Start(uac *UdpAC, hc *HttpConfig) error { + hs.id = time.Now().Format("2006-01-02 15:04:05") + log.Info("==================================================") + log.Info("=== HttpServer (%s) started ===", hs.id) + log.Info("==================================================") + + hs.ua = uac + + port := hc.HttpListenPort + if hc.HttpListenPort == 0 { + port = 62206 + } + // only listen to localhost for security reason. + hs.listenAddr = &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: port, + } + + hs.signals.stop = make(chan struct{}) + + gin.SetMode(gin.ReleaseMode) + hs.ginEngine = gin.New() + hs.ginEngine.Use(gin.LoggerWithWriter(uac.log.Writer())) + hs.ginEngine.Use(gin.Recovery()) + + hs.initRouter() + + hs.httpServer = &http.Server{ + Addr: hs.listenAddr.String(), + Handler: hs.ginEngine, + ReadTimeout: 4500 * time.Millisecond, + WriteTimeout: 4000 * time.Millisecond, + IdleTimeout: 5000 * time.Millisecond, + } + + hs.wg.Add(1) + if hc.EnableTLS { + certFilePath := filepath.Join(ExeDirPath, hc.TLSCertFile) + keyFilePath := filepath.Join(ExeDirPath, hc.TLSKeyFile) + _, err1 := os.Stat(certFilePath) + _, err2 := os.Stat(keyFilePath) + if err1 == nil && err2 == nil { + go func() { + defer hs.wg.Done() + log.Info("Listening https on %s", hs.listenAddr.String()) + var err = hs.httpServer.ListenAndServeTLS(certFilePath, keyFilePath) + if err != nil && err != http.ErrServerClosed { + log.Error("https server close error: %v\n", err) + //panic(err) + } + }() + + return nil + } + } + + go func() { + defer hs.wg.Done() + log.Info("Listening http on %s", hs.listenAddr.String()) + var err = hs.httpServer.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + log.Error("http server close error: %v\n", err) + //panic(err) + } + }() + + hs.running.Store(true) + return nil +} + +// Stop stops the HttpServer by setting the running flag to false, +// closing the stop channel, shutting down the underlying http server, +// waiting for all goroutines to finish, and logging a message indicating +// that the HttpServer has been stopped. +func (hs *HttpAC) Stop() { + if !hs.running.Load() { + // already stopped, do nothing + return + } + + hs.running.Store(false) + close(hs.signals.stop) + ctx, cancel := context.WithTimeout(context.Background(), 5500*time.Millisecond) + hs.httpServer.Shutdown(ctx) + + hs.wg.Wait() + cancel() + cancel = nil + log.Info("==================================================") + log.Info("=== HttpServer (%s) stopped ===", hs.id) + log.Info("==================================================") +} + +func (hs *HttpAC) IsRunning() bool { + return hs.running.Load() +} + +// init gin engine. Must be called at initialization +func (ha *HttpAC) initRouter() { + g := ha.ginEngine + + pluginGrp := g.Group("refresh") + // display login page with templates + pluginGrp.GET("/:token", func(ctx *gin.Context) { + var err error + token := ctx.Param("token") + log.Info("get refresh request. aspId: %s, query: %v", token, ctx.Request.URL.RawQuery) + + if len(token) == 0 { + err = common.ErrUrlPathInvalid + log.Error("path error: %v", err) + ctx.String(http.StatusOK, "{\"errMsg\": \"path error: %v\"}", err) + return + } + + req := &common.HttpRefreshRequest{ + Token: token, + SrcIp: ctx.Query("srcip"), + } + + ha.HandleHttpRefreshOperations(ctx, req) + }) +} + +func (ha *HttpAC) HandleHttpRefreshOperations(c *gin.Context, req *common.HttpRefreshRequest) { + if len(req.SrcIp) == 0 { + c.String(http.StatusOK, "{\"errMsg\": \"empty source ip\"}") + return + } + + netIp := net.ParseIP(req.SrcIp) + if netIp == nil { + c.String(http.StatusOK, "{\"errMsg\": \"invalid source ip\"}") + return + } + + buf, err := base64.StdEncoding.DecodeString(req.Token) + if err != nil || len(buf) != 32 { + c.String(http.StatusOK, "{\"errMsg\": \"invalid token\"}") + return + } + + entry := ha.ua.VerifyAccessToken(req.Token) + if entry == nil { + c.String(http.StatusOK, "{\"errMsg\": \"token verification failed\"}") + return + } + + var found bool + var newSrcAddr *common.NetAddress + for _, addr := range entry.SrcAddrs { + if addr.Ip == req.SrcIp { + found = true + break + } + } + if !found { + newSrcAddr = &common.NetAddress{ + Ip: req.SrcIp, + Port: entry.SrcAddrs[0].Port, + Protocol: entry.SrcAddrs[0].Protocol, + } + entry.SrcAddrs = append(entry.SrcAddrs, newSrcAddr) + } + + _, err = ha.ua.HandleAccessControl(entry.AgentUser, entry.SrcAddrs, entry.DstAddrs, entry.OpenTime, nil) + if err != nil { + c.String(http.StatusOK, "{\"errMsg\": \"%s\"}", err) + return + } + + c.JSON(http.StatusOK, entry) +} diff --git a/ac/msghandler.go b/ac/msghandler.go index beb970aa..30e58c62 100644 --- a/ac/msghandler.go +++ b/ac/msghandler.go @@ -21,373 +21,399 @@ const ( PASS_PRE_ACCESS_IP ) -func (d *UdpAC) HandleACOperations(ppd *core.PacketParserData) (err error) { - defer d.wg.Done() - d.wg.Add(1) +func (a *UdpAC) HandleUdpACOperations(ppd *core.PacketParserData) (err error) { + a.wg.Add(1) + defer a.wg.Done() - acId := d.config.ACId + acId := a.config.ACId dopMsg := &common.ServerACOpsMsg{} artMsg := &common.ACOpsResultMsg{} transactionId := ppd.SenderTrxId + err = json.Unmarshal(ppd.BodyMessage, dopMsg) + if err != nil { + log.Error("ac(%s#%d)[HandleUdpACOperations] failed to parse %s message: %v", acId, transactionId, core.HeaderTypeToString(ppd.HeaderType), err) + artMsg.ErrCode = common.ErrJsonParseFailed.ErrorCode() + artMsg.ErrMsg = err.Error() + return + } + + srcAddrs := dopMsg.SourceAddrs + dstAddrs := dopMsg.DestinationAddrs + openTimeSec := int(dopMsg.OpenTime) + agentUser := &AgentUser{ + UserId: dopMsg.UserId, + DeviceId: dopMsg.DeviceId, + OrganizationId: dopMsg.OrganizationId, + AuthServiceId: dopMsg.AuthServiceId, + } + artMsg, err = a.HandleAccessControl(agentUser, srcAddrs, dstAddrs, openTimeSec, artMsg) + if err != nil { + log.Error("ac(%s#%d)[HandleUdpACOperations] HandleAccessControl failed, err: %v", acId, err) + } + + // generate ac token and save user and access information + entry := &AccessEntry{ + AgentUser: agentUser, + SrcAddrs: srcAddrs, + DstAddrs: dstAddrs, + OpenTime: openTimeSec, + } + artMsg.ACToken = a.GenerateAccessToken(entry) + + // send ac result + artBytes, _ := json.Marshal(artMsg) + md := &core.MsgData{ + HeaderType: core.NHP_ART, + TransactionId: transactionId, + Compress: true, + PrevParserData: ppd, + Message: artBytes, + } + + // forward to a specific transaction + transaction := ppd.ConnData.FindRemoteTransaction(transactionId) + if transaction == nil { + log.Error("ac(%s#%d)[HandleUdpACOperations] transaction is not available", acId, transactionId) + err = common.ErrTransactionIdNotFound + return err + } + + transaction.NextMsgCh <- md + + return err +} + +func (a *UdpAC) HandleAccessControl(au *AgentUser, srcAddrs []*common.NetAddress, dstAddrs []*common.NetAddress, openTimeSec int, artMsgIn *common.ACOpsResultMsg) (artMsg *common.ACOpsResultMsg, err error) { + if artMsgIn == nil { + artMsg = &common.ACOpsResultMsg{} + } else { + artMsg = artMsgIn + } + // process ac operation - func() { - err = json.Unmarshal(ppd.BodyMessage, dopMsg) - if err != nil { - log.Error("ac(%s#%d)[HandleACOperations] failed to parse %s message: %v", acId, transactionId, core.HeaderTypeToString(ppd.HeaderType), err) - artMsg.ErrCode = common.ErrJsonParseFailed.ErrorCode() - artMsg.ErrMsg = err.Error() - return - } + tempOpenTimeSec := TempPortOpenTime + // 1 sec timeout means exit defaultset access, so exit tempset too + if openTimeSec == 1 { + tempOpenTimeSec = 1 + } - srcAddrs := dopMsg.SourceAddrs - dstAddrs := dopMsg.DestinationAddrs - openTimeSec := int(dopMsg.OpenTime) - tempOpenTimeSec := TempPortOpenTime - // 1 sec timeout means exit defaultset access, so exit tempset too - if openTimeSec == 1 { - tempOpenTimeSec = 1 - } + // check empty src address + if len(srcAddrs) == 0 || len(dstAddrs) == 0 { + log.Error("[HandleAccessControl] no source or destination address specified") + err = common.ErrACEmptyPassAddress + artMsg.ErrCode = common.ErrACEmptyPassAddress.ErrorCode() + artMsg.ErrMsg = err.Error() + return + } - // check empty src address - if len(srcAddrs) == 0 || len(dstAddrs) == 0 { - log.Error("ac(%s#%d)[HandleACOperations] no source or destination address specified", acId, transactionId) - err = common.ErrACEmptyPassAddress - artMsg.ErrCode = common.ErrACEmptyPassAddress.ErrorCode() - artMsg.ErrMsg = err.Error() - return - } + // ac ipset operations + if a.ipset == nil { + log.Error("[HandleAccessControl] ipset is nil") + err = common.ErrACIPSetNotFound + artMsg.ErrCode = common.ErrACIPSetNotFound.ErrorCode() + artMsg.ErrMsg = err.Error() + return + } - // ac ipset operations - if d.ipset == nil { - log.Error("ac(%s#%d)[HandleACOperations] ipset is nil", acId, transactionId) - err = common.ErrACIPSetNotFound - artMsg.ErrCode = common.ErrACIPSetNotFound.ErrorCode() - artMsg.ErrMsg = err.Error() - return + // use ac default ip to override empty destination ip + if len(a.config.DefaultIp) > 0 { + for _, addr := range dstAddrs { + if len(addr.Ip) == 0 { + addr.Ip = a.config.DefaultIp + } } + } - // use ac default ip to override empty destination ip - if len(d.config.DefaultIp) > 0 { - for _, addr := range dstAddrs { - if len(addr.Ip) == 0 { - addr.Ip = d.config.DefaultIp - } + switch a.IpPassMode() { + // pass the knock ip immediately + case PASS_KNOCK_IP: + for _, srcAddr := range srcAddrs { + var ipType utils.IPTYPE + var ipNet *net.IPNet + if strings.Contains(srcAddr.Ip, ":") { + ipType = utils.IPV6 + _, ipNet, _ = net.ParseCIDR(srcAddr.Ip + "/121") + } else { + ipType = utils.IPV4 + _, ipNet, _ = net.ParseCIDR(srcAddr.Ip + "/25") } - } + log.Debug("src ip is %s, net range is %s", srcAddr, ipNet.String()) + + for _, dstAddr := range dstAddrs { + // for tcp + if len(dstAddr.Protocol) == 0 || dstAddr.Protocol == "tcp" || dstAddr.Protocol == "any" { + ipHashStr := fmt.Sprintf("%s,%d,%s", srcAddr.Ip, dstAddr.Port, dstAddr.Ip) + if dstAddr.Port == 0 { + ipHashStr = fmt.Sprintf("%s,1-65535,%s", srcAddr.Ip, dstAddr.Ip) + } - switch d.IpPassMode() { - // pass the knock ip immediately - case PASS_KNOCK_IP: - for _, srcAddr := range srcAddrs { - var ipType utils.IPTYPE - var ipNet *net.IPNet - if strings.Contains(srcAddr.Ip, ":") { - ipType = utils.IPV6 - _, ipNet, _ = net.ParseCIDR(srcAddr.Ip + "/121") - } else { - ipType = utils.IPV4 - _, ipNet, _ = net.ParseCIDR(srcAddr.Ip + "/25") + _, err = a.ipset.Add(ipType, 1, openTimeSec, ipHashStr) + if err != nil { + log.Error("[HandleAccessControl] add ipset %s error: %v", ipHashStr, err) + err = common.ErrACIPSetOperationFailed + artMsg.ErrCode = common.ErrACIPSetOperationFailed.ErrorCode() + artMsg.ErrMsg = err.Error() + return + } } - log.Debug("src ip is %s, net range is %s", srcAddr, ipNet.String()) - for _, dstAddr := range dstAddrs { - // for tcp - if len(dstAddr.Protocol) == 0 || dstAddr.Protocol == "tcp" || dstAddr.Protocol == "any" { - ipHashStr := fmt.Sprintf("%s,%d,%s", srcAddr.Ip, dstAddr.Port, dstAddr.Ip) - if dstAddr.Port == 0 { - ipHashStr = fmt.Sprintf("%s,1-65535,%s", srcAddr.Ip, dstAddr.Ip) - } + // for udp + if len(dstAddr.Protocol) == 0 || dstAddr.Protocol == "udp" || dstAddr.Protocol == "any" { + ipHashStr := fmt.Sprintf("%s,udp:%d,%s", srcAddr.Ip, dstAddr.Port, dstAddr.Ip) + if dstAddr.Port == 0 { + ipHashStr = fmt.Sprintf("%s,udp:1-65535,%s", srcAddr.Ip, dstAddr.Ip) + } - _, err = d.ipset.Add(ipType, 1, openTimeSec, ipHashStr) - if err != nil { - log.Error("ac(%s#%d)[HandleACOperations] add ipset %s error: %v", acId, transactionId, ipHashStr, err) - err = common.ErrACIPSetOperationFailed - artMsg.ErrCode = common.ErrACIPSetOperationFailed.ErrorCode() - artMsg.ErrMsg = err.Error() - return - } + _, err = a.ipset.Add(ipType, 1, openTimeSec, ipHashStr) + if err != nil { + log.Error("[HandleAccessControl] add ipset %s error: %v", ipHashStr, err) + err = common.ErrACIPSetOperationFailed + artMsg.ErrCode = common.ErrACIPSetOperationFailed.ErrorCode() + artMsg.ErrMsg = err.Error() + return } + } - // for udp - if len(dstAddr.Protocol) == 0 || dstAddr.Protocol == "udp" || dstAddr.Protocol == "any" { - ipHashStr := fmt.Sprintf("%s,udp:%d,%s", srcAddr.Ip, dstAddr.Port, dstAddr.Ip) - if dstAddr.Port == 0 { - ipHashStr = fmt.Sprintf("%s,udp:1-65535,%s", srcAddr.Ip, dstAddr.Ip) - } + // for icmp ping + if dstAddr.Port == 0 && (len(dstAddr.Protocol) == 0 || dstAddr.Protocol == "any") { + for _, dstAddr := range dstAddrs { + ipHashStr := fmt.Sprintf("%s,icmp:8/0,%s", srcAddr.Ip, dstAddr.Ip) - _, err = d.ipset.Add(ipType, 1, openTimeSec, ipHashStr) + _, err = a.ipset.Add(ipType, 1, openTimeSec, ipHashStr) if err != nil { - log.Error("ac(%s#%d)[HandleACOperations] add ipset %s error: %v", acId, transactionId, ipHashStr, err) + log.Error("[HandleAccessControl] add ipset %s error: %v", ipHashStr, err) err = common.ErrACIPSetOperationFailed artMsg.ErrCode = common.ErrACIPSetOperationFailed.ErrorCode() artMsg.ErrMsg = err.Error() return } } + } - // for icmp ping - if dstAddr.Port == 0 && (len(dstAddr.Protocol) == 0 || dstAddr.Protocol == "any") { - for _, dstAddr := range dstAddrs { - ipHashStr := fmt.Sprintf("%s,icmp:8/0,%s", srcAddr.Ip, dstAddr.Ip) - - _, err = d.ipset.Add(ipType, 1, openTimeSec, ipHashStr) - if err != nil { - log.Error("ac(%s#%d)[HandleACOperations] add ipset %s error: %v", acId, transactionId, ipHashStr, err) - err = common.ErrACIPSetOperationFailed - artMsg.ErrCode = common.ErrACIPSetOperationFailed.ErrorCode() - artMsg.ErrMsg = err.Error() - return - } + // add tempset + if ipNet != nil { + netStr := ipNet.String() + if len(dstAddr.Protocol) == 0 || dstAddr.Protocol == "tcp" || dstAddr.Protocol == "any" { + netHashStr := fmt.Sprintf("%s,%d", netStr, dstAddr.Port) + if dstAddr.Port == 0 { + netHashStr = fmt.Sprintf("%s,1-65535", netStr) } + _, err = a.ipset.Add(ipType, 4, tempOpenTimeSec, netHashStr) } - // add tempset - if ipNet != nil { - netStr := ipNet.String() - if len(dstAddr.Protocol) == 0 || dstAddr.Protocol == "tcp" || dstAddr.Protocol == "any" { - netHashStr := fmt.Sprintf("%s,%d", netStr, dstAddr.Port) - if dstAddr.Port == 0 { - netHashStr = fmt.Sprintf("%s,1-65535", netStr) - } - _, err = d.ipset.Add(ipType, 4, tempOpenTimeSec, netHashStr) - } - - if len(dstAddr.Protocol) == 0 || dstAddr.Protocol == "udp" || dstAddr.Protocol == "any" { - netHashStr := fmt.Sprintf("%s,udp:%d", netStr, dstAddr.Port) - if dstAddr.Port == 0 { - netHashStr = fmt.Sprintf("%s,udp:1-65535", netStr) - } - _, err = d.ipset.Add(ipType, 4, tempOpenTimeSec, netHashStr) + if len(dstAddr.Protocol) == 0 || dstAddr.Protocol == "udp" || dstAddr.Protocol == "any" { + netHashStr := fmt.Sprintf("%s,udp:%d", netStr, dstAddr.Port) + if dstAddr.Port == 0 { + netHashStr = fmt.Sprintf("%s,udp:1-65535", netStr) } + _, err = a.ipset.Add(ipType, 4, tempOpenTimeSec, netHashStr) + } - if dstAddr.Port == 0 && (len(dstAddr.Protocol) == 0 || dstAddr.Protocol == "any") { - netHashStr := fmt.Sprintf("%s,icmp:8/0", netStr) - _, err = d.ipset.Add(ipType, 4, tempOpenTimeSec, netHashStr) - } + if dstAddr.Port == 0 && (len(dstAddr.Protocol) == 0 || dstAddr.Protocol == "any") { + netHashStr := fmt.Sprintf("%s,icmp:8/0", netStr) + _, err = a.ipset.Add(ipType, 4, tempOpenTimeSec, netHashStr) } } } + } - // return temporary listened port(s) and nhp access token, then pass the real ip when agent sends access message - case PASS_PRE_ACCESS_IP: - fallthrough - default: - // ac open a temporary tcp or udp port for access - dstIp := net.ParseIP(dstAddrs[0].Ip) - if dstIp == nil { - log.Error("ac(%s#%d)[HandleACOperations] destination IP %s is invalid", acId, transactionId, dstAddrs[0].Ip) - err = common.ErrInvalidIpAddress - artMsg.ErrCode = common.ErrInvalidIpAddress.ErrorCode() - artMsg.ErrMsg = err.Error() - return - } - - var ipType utils.IPTYPE - var netStr string - var netStr1 string - var pickedPort int - var tcpListener *net.TCPListener - var udpListener *net.UDPConn - - if strings.Contains(dstAddrs[0].Ip, ":") { - ipType = utils.IPV6 - netStr = "0:0:0:0:0:0:0:0/0" - } else { - // since ipset does not allow full ip range 0.0.0.0/0, we use two ip ranges - ipType = utils.IPV4 - netStr = "0.0.0.0/1" - netStr1 = "128.0.0.0/1" - } - - // openning temp tcp access - tcpListener, err = net.ListenTCP("tcp", &net.TCPAddr{ - IP: dstIp, - Port: 0, // ephemeral port - }) + // return temporary listened port(s) and nhp access token, then pass the real ip when agent sends access message + case PASS_PRE_ACCESS_IP: + fallthrough + default: + // ac open a temporary tcp or udp port for access + dstIp := net.ParseIP(dstAddrs[0].Ip) + if dstIp == nil { + log.Error("[HandleAccessControl] destination IP %s is invalid", dstAddrs[0].Ip) + err = common.ErrInvalidIpAddress + artMsg.ErrCode = common.ErrInvalidIpAddress.ErrorCode() + artMsg.ErrMsg = err.Error() + return + } - if err != nil { - log.Error("ac(%s#%d)[HandleACOperations] temporary tcp listening error: %v", acId, transactionId, err) - err = common.ErrACTempPortListenFailed - artMsg.ErrCode = common.ErrACTempPortListenFailed.ErrorCode() - artMsg.ErrMsg = err.Error() - return - } + var ipType utils.IPTYPE + var netStr string + var netStr1 string + var pickedPort int + var tcpListener *net.TCPListener + var udpListener *net.UDPConn - // retrieve local port - tladdr := tcpListener.Addr() - tlocalAddr, locErr := net.ResolveTCPAddr(tladdr.Network(), tladdr.String()) - if locErr != nil { - log.Error("ac(%s#%d)[HandleACOperations] resolve local TCPAddr error: %v", acId, transactionId, locErr) - err = common.ErrACResolveTempPortFailed - artMsg.ErrCode = common.ErrACResolveTempPortFailed.ErrorCode() - artMsg.ErrMsg = err.Error() - return - } + if strings.Contains(dstAddrs[0].Ip, ":") { + ipType = utils.IPV6 + netStr = "0:0:0:0:0:0:0:0/0" + } else { + // since ipset does not allow full ip range 0.0.0.0/0, we use two ip ranges + ipType = utils.IPV4 + netStr = "0.0.0.0/1" + netStr1 = "128.0.0.0/1" + } - log.Debug("open temporary tcp port %s", tlocalAddr.String()) - portHashStr := fmt.Sprintf("%s,%d", netStr, tlocalAddr.Port) - _, err = d.ipset.Add(ipType, 4, tempOpenTimeSec, portHashStr) - if err != nil { - log.Error("ac(%s#%d)[HandleACOperations] add ipset %s error: %v", acId, transactionId, portHashStr, err) - err = common.ErrACIPSetOperationFailed - artMsg.ErrCode = common.ErrACIPSetOperationFailed.ErrorCode() - artMsg.ErrMsg = err.Error() - return - } - portHashStr = fmt.Sprintf("%s,%d", netStr1, tlocalAddr.Port) - _, err = d.ipset.Add(ipType, 4, tempOpenTimeSec, portHashStr) - if err != nil { - log.Error("ac(%s#%d)[HandleACOperations] add ipset %s error: %v", acId, transactionId, portHashStr, err) - err = common.ErrACIPSetOperationFailed - artMsg.ErrCode = common.ErrACIPSetOperationFailed.ErrorCode() - artMsg.ErrMsg = err.Error() - return - } + // openning temp tcp access + tcpListener, err = net.ListenTCP("tcp", &net.TCPAddr{ + IP: dstIp, + Port: 0, // ephemeral port + }) - pickedPort = tlocalAddr.Port - log.Info("ac(%s#%d)[HandleACOperations] open temporary tcp port on %s", acId, transactionId, tladdr.String()) + if err != nil { + log.Error("[HandleAccessControl] temporary tcp listening error: %v", err) + err = common.ErrACTempPortListenFailed + artMsg.ErrCode = common.ErrACTempPortListenFailed.ErrorCode() + artMsg.ErrMsg = err.Error() + return + } - // for temp udp access - udpListener, err = net.ListenUDP("udp", &net.UDPAddr{ - IP: dstIp, - Port: pickedPort, // ephemeral port(0) or continue with previously picked tcp port - }) - if err != nil { - log.Error("ac(%s#%d)[HandleACOperations] temporary udp listening error: %v", acId, transactionId, err) - err = common.ErrACTempPortListenFailed - artMsg.ErrCode = common.ErrACTempPortListenFailed.ErrorCode() - artMsg.ErrMsg = err.Error() - return - } + // retrieve local port + tladdr := tcpListener.Addr() + tlocalAddr, locErr := net.ResolveTCPAddr(tladdr.Network(), tladdr.String()) + if locErr != nil { + log.Error("[HandleAccessControl] resolve local TCPAddr error: %v", locErr) + err = common.ErrACResolveTempPortFailed + artMsg.ErrCode = common.ErrACResolveTempPortFailed.ErrorCode() + artMsg.ErrMsg = err.Error() + return + } - // retrieve local port - uladdr := udpListener.LocalAddr() - _, locErr = net.ResolveUDPAddr(uladdr.Network(), uladdr.String()) - if locErr != nil { - log.Error("ac(%s#%d)[HandleACOperations] resolve local UDPAddr error: %v", acId, transactionId, locErr) - err = common.ErrACResolveTempPortFailed - artMsg.ErrCode = common.ErrACResolveTempPortFailed.ErrorCode() - artMsg.ErrMsg = err.Error() - return - } + log.Debug("open temporary tcp port %s", tlocalAddr.String()) + portHashStr := fmt.Sprintf("%s,%d", netStr, tlocalAddr.Port) + _, err = a.ipset.Add(ipType, 4, tempOpenTimeSec, portHashStr) + if err != nil { + log.Error("[HandleAccessControl] add ipset %s error: %v", portHashStr, err) + err = common.ErrACIPSetOperationFailed + artMsg.ErrCode = common.ErrACIPSetOperationFailed.ErrorCode() + artMsg.ErrMsg = err.Error() + return + } + portHashStr = fmt.Sprintf("%s,%d", netStr1, tlocalAddr.Port) + _, err = a.ipset.Add(ipType, 4, tempOpenTimeSec, portHashStr) + if err != nil { + log.Error("[HandleAccessControl] add ipset %s error: %v", portHashStr, err) + err = common.ErrACIPSetOperationFailed + artMsg.ErrCode = common.ErrACIPSetOperationFailed.ErrorCode() + artMsg.ErrMsg = err.Error() + return + } - log.Debug("open temporary udp port %s", tlocalAddr.String()) - pickedPort = tlocalAddr.Port - portHashStr = fmt.Sprintf("%s,udp:%d", netStr, tlocalAddr.Port) - _, err = d.ipset.Add(ipType, 4, tempOpenTimeSec, portHashStr) - if err != nil { - log.Error("ac(%s#%d)[HandleACOperations] add ipset %s error: %v", acId, transactionId, portHashStr, err) - err = common.ErrACIPSetOperationFailed - artMsg.ErrCode = common.ErrACIPSetOperationFailed.ErrorCode() - artMsg.ErrMsg = err.Error() - return - } - portHashStr = fmt.Sprintf("%s,udp:%d", netStr1, tlocalAddr.Port) - _, err = d.ipset.Add(ipType, 4, tempOpenTimeSec, portHashStr) - if err != nil { - log.Error("ac(%s#%d)[HandleACOperations] add ipset %s error: %v", acId, transactionId, portHashStr, err) - err = common.ErrACIPSetOperationFailed - artMsg.ErrCode = common.ErrACIPSetOperationFailed.ErrorCode() - artMsg.ErrMsg = err.Error() - return - } + pickedPort = tlocalAddr.Port + log.Info("[HandleAccessControl] open temporary tcp port on %s", tladdr.String()) - log.Info("ac(%s#%d)[HandleACOperations] open temporary udp port on %s", acId, transactionId, tladdr.String()) + // for temp udp access + udpListener, err = net.ListenUDP("udp", &net.UDPAddr{ + IP: dstIp, + Port: pickedPort, // ephemeral port(0) or continue with previously picked tcp port + }) + if err != nil { + log.Error("[HandleAccessControl] temporary udp listening error: %v", err) + err = common.ErrACTempPortListenFailed + artMsg.ErrCode = common.ErrACTempPortListenFailed.ErrorCode() + artMsg.ErrMsg = err.Error() + return + } - agentUser := &AgentUser{ - UserId: dopMsg.UserId, - DeviceId: dopMsg.DeviceId, - OrganizationId: dopMsg.OrganizationId, - } + // retrieve local port + uladdr := udpListener.LocalAddr() + _, locErr = net.ResolveUDPAddr(uladdr.Network(), uladdr.String()) + if locErr != nil { + log.Error("[HandleAccessControl] resolve local UDPAddr error: %v", locErr) + err = common.ErrACResolveTempPortFailed + artMsg.ErrCode = common.ErrACResolveTempPortFailed.ErrorCode() + artMsg.ErrMsg = err.Error() + return + } - artMsg.PreAccessAction = &common.PreAccessInfo{ - AccessPort: strconv.Itoa(pickedPort), - ACPubKey: d.device.PublicKeyExBase64(), - ACToken: d.GenerateAccessToken(agentUser), - } + log.Debug("open temporary udp port %s", tlocalAddr.String()) + pickedPort = tlocalAddr.Port + portHashStr = fmt.Sprintf("%s,udp:%d", netStr, tlocalAddr.Port) + _, err = a.ipset.Add(ipType, 4, tempOpenTimeSec, portHashStr) + if err != nil { + log.Error("[HandleAccessControl] add ipset %s error: %v", portHashStr, err) + err = common.ErrACIPSetOperationFailed + artMsg.ErrCode = common.ErrACIPSetOperationFailed.ErrorCode() + artMsg.ErrMsg = err.Error() + return + } + portHashStr = fmt.Sprintf("%s,udp:%d", netStr1, tlocalAddr.Port) + _, err = a.ipset.Add(ipType, 4, tempOpenTimeSec, portHashStr) + if err != nil { + log.Error("[HandleAccessControl] add ipset %s error: %v", portHashStr, err) + err = common.ErrACIPSetOperationFailed + artMsg.ErrCode = common.ErrACIPSetOperationFailed.ErrorCode() + artMsg.ErrMsg = err.Error() + return + } - if tcpListener != nil { - d.wg.Add(1) - go d.tcpTempAccessHandler(transactionId, tcpListener, agentUser, tempOpenTimeSec, dopMsg) - } + log.Info("[HandleAccessControl] open temporary udp port on %s", tladdr.String()) - if udpListener != nil { - d.wg.Add(1) - go d.udpTempAccessHandler(transactionId, udpListener, agentUser, tempOpenTimeSec, dopMsg) - } + tempEntry := &AccessEntry{ + AgentUser: au, + SrcAddrs: srcAddrs, + DstAddrs: dstAddrs, + OpenTime: tempOpenTimeSec, + } + artMsg.PreAccessAction = &common.PreAccessInfo{ + AccessPort: strconv.Itoa(pickedPort), + ACPubKey: a.device.PublicKeyExBase64(), + ACToken: a.GenerateAccessToken(tempEntry), } - log.Info("ac(%s#%d)[HandleACOperations] succeed", acId, transactionId) - artMsg.ErrCode = common.ErrSuccess.ErrorCode() - artMsg.OpenTime = dopMsg.OpenTime - }() - - // send ac result - artBytes, _ := json.Marshal(artMsg) + if tcpListener != nil { + a.wg.Add(1) + go a.tcpTempAccessHandler(tcpListener, tempOpenTimeSec, dstAddrs, openTimeSec) + } - md := &core.MsgData{ - HeaderType: core.NHP_ART, - TransactionId: transactionId, - Compress: true, - PrevParserData: ppd, - Message: artBytes, + if udpListener != nil { + a.wg.Add(1) + go a.udpTempAccessHandler(udpListener, tempOpenTimeSec, dstAddrs, openTimeSec) + } } - // forward to a specific transaction - transaction := ppd.ConnData.FindRemoteTransaction(transactionId) - if transaction == nil { - log.Error("ac(%s#%d)[HandleACOperations] transaction is not available", acId, transactionId) - err = common.ErrTransactionIdNotFound - return err - } + log.Info("[HandleAccessControl] succeed") - transaction.NextMsgCh <- md + artMsg.ErrCode = common.ErrSuccess.ErrorCode() + artMsg.OpenTime = uint32(openTimeSec) - return nil + return } -func (d *UdpAC) tcpTempAccessHandler(transactionId uint64, listener *net.TCPListener, au *AgentUser, timeoutSec int, dopMsg *common.ServerACOpsMsg) { - defer d.wg.Done() - defer d.DeleteAccessToken(au) +func (a *UdpAC) tcpTempAccessHandler(listener *net.TCPListener, timeoutSec int, dstAddrs []*common.NetAddress, openTimeSec int) { + defer a.wg.Done() defer listener.Close() - acId := d.config.ACId // accept only the first incoming tcp connection startTime := time.Now() deadlineTime := startTime.Add(time.Duration(timeoutSec) * time.Second) localAddrStr := listener.Addr().String() err := listener.SetDeadline(deadlineTime) if err != nil { - log.Error("ac(%s#%d)[HandleACOperations] temporary port on %s failed to set tcp listen timeout", acId, transactionId, localAddrStr) + log.Error("[tcpTempAccessHandler] temporary port on %s failed to set tcp listen timeout", localAddrStr) return } conn, err := listener.Accept() if err != nil { - log.Error("ac(%s#%d)[HandleACOperations] temporary port on %s tcp listen timeout", acId, transactionId, localAddrStr) + log.Error("[tcpTempAccessHandler] temporary port on %s tcp listen timeout", localAddrStr) return } defer conn.Close() err = conn.SetDeadline(deadlineTime) if err != nil { - log.Error("ac(%s#%d)[HandleACOperations] temporary port on %s failed to set tcp conn timeout", acId, transactionId, localAddrStr) + log.Error("[tcpTempAccessHandler] temporary port on %s failed to set tcp conn timeout", localAddrStr) return } remoteAddrStr := conn.RemoteAddr().String() - pkt := d.device.AllocatePoolPacket() - defer d.device.ReleasePoolPacket(pkt) + pkt := a.device.AllocatePoolPacket() + defer a.device.ReleasePoolPacket(pkt) // monitor stop signals and quit connection earlier ctx, ctxCancel := context.WithDeadline(context.Background(), deadlineTime) defer ctxCancel() - go d.tempConnTerminator(conn, ctx) + go a.tempConnTerminator(conn, ctx) // tcp recv common header first n, err := conn.Read(pkt.Buf[:core.HeaderCommonSize]) if err != nil || n < core.HeaderCommonSize { - log.Error("ac(%s#%d)[HandleACOperations] failed to receive tcp packet header from remote address %s (%v)", acId, transactionId, remoteAddrStr, err) + log.Error("[tcpTempAccessHandler] failed to receive tcp packet header from remote address %s (%v)", remoteAddrStr, err) return } @@ -395,7 +421,7 @@ func (d *UdpAC) tcpTempAccessHandler(transactionId uint64, listener *net.TCPList // check type and payload size msgType, msgSize := pkt.HeaderTypeAndSize() if msgType != core.NHP_ACC { - log.Error("ac(%s#%d)[HandleACOperations] message type is not %s, close connection", acId, transactionId, core.HeaderTypeToString(core.NHP_ACC)) + log.Error("[tcpTempAccessHandler] message type is not %s, close connection", core.HeaderTypeToString(core.NHP_ACC)) return } @@ -409,13 +435,13 @@ func (d *UdpAC) tcpTempAccessHandler(transactionId uint64, listener *net.TCPList remainingSize := packetSize - n n, err = conn.Read(pkt.Buf[n:packetSize]) if err != nil || n < remainingSize { - log.Error("ac(%s#%d)[HandleACOperations] failed to receive tcp message body from remote address %s (%v)", acId, transactionId, remoteAddrStr, err) + log.Error("[tcpTempAccessHandler] failed to receive tcp message body from remote address %s (%v)", remoteAddrStr, err) return } pkt.Content = pkt.Buf[:packetSize] - log.Trace("receive tcp access packet (%s -> %s): %+v", remoteAddrStr, localAddrStr, pkt.Content) - log.Info("ac(%s#%d)[HandleACOperations] receive tcp access message (%s -> %s)", acId, transactionId, remoteAddrStr, localAddrStr) + //log.Trace("[tcpTempAccessHandler]receive tcp access packet (%s -> %s): %+v", remoteAddrStr, localAddrStr, pkt.Content) + log.Info("[tcpTempAccessHandler] receive tcp access message (%s -> %s)", remoteAddrStr, localAddrStr) pd := &core.PacketData{ BasePacket: pkt, @@ -424,41 +450,33 @@ func (d *UdpAC) tcpTempAccessHandler(transactionId uint64, listener *net.TCPList DecryptedMsgCh: make(chan *core.PacketParserData), } - if !d.IsRunning() { - log.Error("ac(%s#%d)[HandleACOperations] PacketData channel closed or being closed, skip decrypting", acId, transactionId) + if !a.IsRunning() { + log.Error("[tcpTempAccessHandler] PacketData channel closed or being closed, skip decrypting") return } // start message decryption - d.device.RecvPacketToMsg(pd) + a.device.RecvPacketToMsg(pd) // waiting for message decryption accPpd := <-pd.DecryptedMsgCh close(pd.DecryptedMsgCh) if accPpd.Error != nil { - log.Error("ac(%s#%d)[HandleACOperations] failed to decrypt tcp access message: %v", acId, transactionId, accPpd.Error) + log.Error("[tcpTempAccessHandler] failed to decrypt tcp access message: %v", accPpd.Error) return } accMsg := &common.AgentAccessMsg{} err = json.Unmarshal(accPpd.BodyMessage, accMsg) if err != nil { - log.Error("ac(%s#%d)[HandleACOperations] failed to parse %s message: %v", acId, transactionId, core.HeaderTypeToString(accPpd.HeaderType), err) + log.Error("[tcpTempAccessHandler] failed to parse %s message: %v", core.HeaderTypeToString(accPpd.HeaderType), err) return } - remoteAgentUser := &AgentUser{ - UserId: accMsg.UserId, - DeviceId: accMsg.DeviceId, - OrganizationId: accMsg.OrganizationId, - } - - if d.VerifyAccessToken(remoteAgentUser, accMsg.ACToken) { + if a.VerifyAccessToken(accMsg.ACToken) != nil { remoteAddr, _ := net.ResolveTCPAddr(conn.RemoteAddr().Network(), conn.RemoteAddr().String()) srcAddrIp := remoteAddr.IP.String() - dstAddrs := dopMsg.DestinationAddrs - openTimeSec := int(dopMsg.OpenTime) var ipType utils.IPTYPE if strings.Contains(dstAddrs[0].Ip, ":") { ipType = utils.IPV6 @@ -472,43 +490,41 @@ func (d *UdpAC) tcpTempAccessHandler(transactionId uint64, listener *net.TCPList ipHashStr = fmt.Sprintf("%s,1-65535,%s", srcAddrIp, dstAddr.Ip) } - _, err = d.ipset.Add(ipType, 1, openTimeSec, ipHashStr) + _, err = a.ipset.Add(ipType, 1, openTimeSec, ipHashStr) if err != nil { - log.Error("ac(%s#%d)[HandleACOperations] add ipset %s error: %v", acId, transactionId, ipHashStr, err) + log.Error("[tcpTempAccessHandler] add ipset %s error: %v", ipHashStr, err) return } } } } -func (d *UdpAC) udpTempAccessHandler(transactionId uint64, conn *net.UDPConn, au *AgentUser, timeoutSec int, dopMsg *common.ServerACOpsMsg) { - defer d.wg.Done() - defer d.DeleteAccessToken(au) +func (a *UdpAC) udpTempAccessHandler(conn *net.UDPConn, timeoutSec int, dstAddrs []*common.NetAddress, openTimeSec int) { + defer a.wg.Done() defer conn.Close() - acId := d.config.ACId // listen to accept and handle only one incoming connection startTime := time.Now() deadlineTime := startTime.Add(time.Duration(timeoutSec) * time.Second) localAddrStr := conn.LocalAddr().String() err := conn.SetDeadline(deadlineTime) if err != nil { - log.Error("ac(%s#%d)[HandleACOperations] temporary port on %s failed to set udp conn timeout", acId, transactionId, localAddrStr) + log.Error("[udpTempAccessHandler] temporary port on %s failed to set udp conn timeout", localAddrStr) return } - pkt := d.device.AllocatePoolPacket() - defer d.device.ReleasePoolPacket(pkt) + pkt := a.device.AllocatePoolPacket() + defer a.device.ReleasePoolPacket(pkt) // monitor stop signals and quit connection earlier ctx, ctxCancel := context.WithDeadline(context.Background(), deadlineTime) defer ctxCancel() - go d.tempConnTerminator(conn, ctx) + go a.tempConnTerminator(conn, ctx) // udp recv, blocking until packet arrives or deadline reaches n, remoteAddr, err := conn.ReadFromUDP(pkt.Buf[:]) if err != nil || n < core.HeaderCommonSize { - log.Error("ac(%s#%d)[HandleACOperations] failed to receive udp packet (%v)", acId, transactionId, err) + log.Error("[udpTempAccessHandler] failed to receive udp packet (%v)", err) return } @@ -518,7 +534,7 @@ func (d *UdpAC) udpTempAccessHandler(transactionId uint64, conn *net.UDPConn, au // check type and payload size msgType, msgSize := pkt.HeaderTypeAndSize() if msgType != core.NHP_ACC { - log.Error("ac(%s#%d)[HandleACOperations] message type is not %s, close connection", acId, transactionId, core.HeaderTypeToString(core.NHP_ACC)) + log.Error("[udpTempAccessHandler] message type is not %s, close connection", core.HeaderTypeToString(core.NHP_ACC)) return } @@ -531,12 +547,12 @@ func (d *UdpAC) udpTempAccessHandler(transactionId uint64, conn *net.UDPConn, au } if n != packetSize { - log.Error("ac(%s#%d)[HandleACOperations] udp packet size incorrect from remote address %s", acId, transactionId, remoteAddrStr) + log.Error("[udpTempAccessHandler] udp packet size incorrect from remote address %s", remoteAddrStr) return } log.Trace("receive udp access packet (%s -> %s): %+v", remoteAddrStr, localAddrStr, pkt.Content) - log.Info("ac(%s#%d)[HandleACOperations] receive udp access message (%s -> %s)", acId, transactionId, remoteAddrStr, localAddrStr) + log.Info("[udpTempAccessHandler] receive udp access message (%s -> %s)", remoteAddrStr, localAddrStr) pd := &core.PacketData{ BasePacket: pkt, @@ -545,40 +561,32 @@ func (d *UdpAC) udpTempAccessHandler(transactionId uint64, conn *net.UDPConn, au DecryptedMsgCh: make(chan *core.PacketParserData), } - if !d.IsRunning() { - log.Error("ac(%s#%d)[HandleACOperations] PacketData channel closed or being closed, skip decrypting", acId, transactionId) + if !a.IsRunning() { + log.Error("[udpTempAccessHandler] PacketData channel closed or being closed, skip decrypting") return } // start packet decryption - d.device.RecvPacketToMsg(pd) + a.device.RecvPacketToMsg(pd) // waiting for packet decryption accPpd := <-pd.DecryptedMsgCh close(pd.DecryptedMsgCh) if accPpd.Error != nil { - log.Error("ac(%s#%d)[HandleACOperations] failed to decrypt udp access message: %v", acId, transactionId, accPpd.Error) + log.Error("[udpTempAccessHandler] failed to decrypt udp access message: %v", accPpd.Error) return } accMsg := &common.AgentAccessMsg{} err = json.Unmarshal(accPpd.BodyMessage, accMsg) if err != nil { - log.Error("ac(%s#%d)[HandleACOperations] failed to parse %s message: %v", acId, transactionId, core.HeaderTypeToString(accPpd.HeaderType), err) + log.Error("[udpTempAccessHandler] failed to parse %s message: %v", core.HeaderTypeToString(accPpd.HeaderType), err) return } - remoteAgentUser := &AgentUser{ - UserId: accMsg.UserId, - DeviceId: accMsg.DeviceId, - OrganizationId: accMsg.OrganizationId, - } - - if d.VerifyAccessToken(remoteAgentUser, accMsg.ACToken) { + if a.VerifyAccessToken(accMsg.ACToken) != nil { srcAddrIp := remoteAddr.IP.String() - dstAddrs := dopMsg.DestinationAddrs - openTimeSec := int(dopMsg.OpenTime) var ipType utils.IPTYPE if strings.Contains(dstAddrs[0].Ip, ":") { ipType = utils.IPV6 @@ -593,9 +601,9 @@ func (d *UdpAC) udpTempAccessHandler(transactionId uint64, conn *net.UDPConn, au ipHashStr = fmt.Sprintf("%s,udp:1-65535,%s", srcAddrIp, dstAddr.Ip) } - _, err = d.ipset.Add(ipType, 1, openTimeSec, ipHashStr) + _, err = a.ipset.Add(ipType, 1, openTimeSec, ipHashStr) if err != nil { - log.Error("ac(%s#%d)[HandleACOperations] add ipset %s error: %v", acId, transactionId, ipHashStr, err) + log.Error("[udpTempAccessHandler] add ipset %s error: %v", ipHashStr, err) return } } @@ -603,15 +611,15 @@ func (d *UdpAC) udpTempAccessHandler(transactionId uint64, conn *net.UDPConn, au // for ping if dstAddr.Port == 0 && (len(dstAddr.Protocol) == 0 || dstAddr.Protocol == "any") { ipHashStr := fmt.Sprintf("%s,icmp:8/0,%s", remoteAddr.IP.String(), dstAddr.Ip) - d.ipset.Add(ipType, 1, openTimeSec, ipHashStr) + a.ipset.Add(ipType, 1, openTimeSec, ipHashStr) } } } } -func (d *UdpAC) tempConnTerminator(conn net.Conn, ctx context.Context) { +func (a *UdpAC) tempConnTerminator(conn net.Conn, ctx context.Context) { select { - case <-d.signals.stop: + case <-a.signals.stop: conn.Close() return diff --git a/ac/tokenstore.go b/ac/tokenstore.go new file mode 100644 index 00000000..79e61514 --- /dev/null +++ b/ac/tokenstore.go @@ -0,0 +1,104 @@ +package ac + +import ( + "encoding/base64" + "encoding/binary" + "time" + + "github.com/emmansun/gmsm/sm3" + + "github.com/OpenNHP/opennhp/common" + "github.com/OpenNHP/opennhp/log" +) + +type AgentUser struct { + UserId string + DeviceId string + OrganizationId string + AuthServiceId string +} + +type AccessEntry struct { + AgentUser *AgentUser + SrcAddrs []*common.NetAddress + DstAddrs []*common.NetAddress + OpenTime int + ExpireTime time.Time +} + +type TokenAccessMap = map[string]*AccessEntry // access token mapped into user and access information +type TokenStore = map[string]TokenAccessMap // upper layer of tokens, indexed by first two characters + +func (a *UdpAC) GenerateAccessToken(entry *AccessEntry) string { + var tsBytes [8]byte + currTime := time.Now().UnixNano() + + hash := sm3.New() + binary.BigEndian.PutUint64(tsBytes[:], uint64(currTime)) + au := entry.AgentUser + hash.Write([]byte(a.config.ACId + au.UserId + au.DeviceId + au.OrganizationId + au.AuthServiceId)) + hash.Write(tsBytes[:]) + token := base64.StdEncoding.EncodeToString(hash.Sum(nil)) + hash.Reset() + + a.TokenStoreMutex.Lock() + defer a.TokenStoreMutex.Unlock() + + entry.ExpireTime = time.Now().Add(time.Duration(entry.OpenTime) * time.Second) + tokenMap, found := a.tokenStore[token[0:1]] + if found { + tokenMap[token] = entry + } else { + tokenMap := make(TokenAccessMap) + tokenMap[token] = entry + a.tokenStore[token[0:1]] = tokenMap + } + + return token +} + +func (a *UdpAC) VerifyAccessToken(token string) *AccessEntry { + a.TokenStoreMutex.Lock() + defer a.TokenStoreMutex.Unlock() + + tokenMap, found := a.tokenStore[token[0:1]] + if found { + entry, found := tokenMap[token] + if found { + return entry + } + } + + return nil +} + +func (a *UdpAC) tokenStoreRefreshRoutine() { + defer a.wg.Done() + defer log.Info("tokenStoreRefreshRoutine stopped") + + log.Info("tokenStoreRefreshRoutine started") + + for { + select { + case <-a.signals.stop: + return + + case <-time.After(TokenStoreRefreshInterval * time.Second): + a.TokenStoreMutex.Lock() + defer a.TokenStoreMutex.Unlock() + + now := time.Now() + for head, tokenMap := range a.tokenStore { + for token, entry := range tokenMap { + if now.After(entry.ExpireTime) { + log.Info("[TokenStore] token %s expired", token) + delete(tokenMap, token) + } + } + if len(tokenMap) == 0 { + delete(a.tokenStore, head) + } + } + } + } +} diff --git a/ac/udpac.go b/ac/udpac.go index 5ce5097c..c5706a15 100644 --- a/ac/udpac.go +++ b/ac/udpac.go @@ -1,815 +1,740 @@ -package ac - -import ( - "encoding/base64" - "encoding/json" - "fmt" - "hash" - "net" - "path/filepath" - "strconv" - "sync" - "sync/atomic" - "time" - - "github.com/OpenNHP/opennhp/common" - "github.com/OpenNHP/opennhp/core" - "github.com/OpenNHP/opennhp/log" - "github.com/OpenNHP/opennhp/utils" - "github.com/OpenNHP/opennhp/version" -) - -var ( - ExeDirPath string -) - -type AgentUser struct { - UserId string - DeviceId string - OrganizationId string - hash hash.Hash -} - -func (au *AgentUser) Hash() string { - au.hash = core.NewHash(core.HASH_SM3) - au.hash.Write([]byte(au.UserId)) - au.hash.Write([]byte(au.DeviceId)) - au.hash.Write([]byte(au.OrganizationId)) - // do not include Agent's PublicKey in calculating hash, because it may vary between Curve25519 and SM2 - sum := au.hash.Sum(nil) - return string(sum) -} - -type AgentUserCodeMap = map[string]*map[string]string // agent hash string first letter > agent hash string > token - -type UdpAC struct { - config *Config - iptables *utils.IPTables - ipset *utils.IPSet - - stats struct { - totalRecvBytes uint64 - totalSendBytes uint64 - } - - log *log.Logger - - remoteConnectionMutex sync.Mutex - remoteConnectionMap map[string]*UdpConn // indexed by remote UDP address - - serverPeerMutex sync.Mutex - serverPeerMap map[string]*core.UdpPeer // indexed by server's public key - - AgentUserTokenMutex sync.Mutex - agentUserCodeMap AgentUserCodeMap - - device *core.Device - wg sync.WaitGroup - running atomic.Bool - - signals struct { - stop chan struct{} - serverMapUpdated chan struct{} - } - - recvMsgCh <-chan *core.PacketParserData - sendMsgCh chan *core.MsgData -} - -type UdpConn struct { - ConnData *core.ConnectionData - netConn *net.UDPConn - connected atomic.Bool - externalAddr string -} - -func (c *UdpConn) Close() { - c.netConn.Close() - c.ConnData.Close() -} - -/* -dirPath: the path of app or shared library entry point -logLevel: 0: silent, 1: error, 2: info, 3: debug, 4: verbose -*/ -func (d *UdpAC) Start(dirPath string, logLevel int) (err error) { - common.ExeDirPath = dirPath - ExeDirPath = dirPath - // init logger - d.log = log.NewLogger("NHP-AC", logLevel, filepath.Join(ExeDirPath, "logs"), "ac") - log.SetGlobalLogger(d.log) - - log.Info("=========================================================") - log.Info("=== NHP-AC %s started ===", version.Version) - log.Info("=== REVISION %s ===", version.CommitId) - log.Info("=== RELEASE %s ===", version.BuildTime) - log.Info("=========================================================") - - // init config - err = d.loadBaseConfig() - if err != nil { - return err - } - - d.iptables, err = utils.NewIPTables() - if err != nil { - log.Error("iptables command not found") - return - } - - d.ipset, err = utils.NewIPSet(false) - if err != nil { - log.Error("ipset command not found") - return - } - - prk, err := base64.StdEncoding.DecodeString(d.config.PrivateKeyBase64) - if err != nil { - log.Error("private key parse error %v\n", err) - return fmt.Errorf("private key parse error %v", err) - } - - d.device = core.NewDevice(core.NHP_AC, prk, nil) - if d.device == nil { - log.Critical("failed to create device %v\n", err) - return fmt.Errorf("failed to create device %v", err) - } - - d.remoteConnectionMap = make(map[string]*UdpConn) - d.serverPeerMap = make(map[string]*core.UdpPeer) - d.agentUserCodeMap = make(AgentUserCodeMap) - - // load peers - d.loadPeers() - - d.signals.stop = make(chan struct{}) - d.signals.serverMapUpdated = make(chan struct{}, 1) - - d.recvMsgCh = d.device.DecryptedMsgQueue - d.sendMsgCh = make(chan *core.MsgData, core.SendQueueSize) - - // start device routines - d.device.Start() - - // start ac routines - d.wg.Add(3) - go d.sendMessageRoutine() - go d.recvMessageRoutine() - go d.maintainServerConnectionRoutine() - - d.running.Store(true) - return nil -} - -func (d *UdpAC) Stop() { - d.running.Store(false) - close(d.signals.stop) - - d.device.Stop() - d.StopConfigWatch() - d.wg.Wait() - close(d.sendMsgCh) - close(d.signals.serverMapUpdated) - - log.Info("==========================") - log.Info("=== NHP-AC stopped ===") - log.Info("==========================") - d.log.Close() -} - -func (d *UdpAC) IsRunning() bool { - return d.running.Load() -} - -func (d *UdpAC) newConnection(addr *net.UDPAddr) (conn *UdpConn) { - conn = &UdpConn{} - var err error - // unlike tcp, udp dial is fast (just socket bind), so no need to run in a thread - conn.netConn, err = net.DialUDP("udp", nil, addr) - if err != nil { - log.Error("could not connect to remote addr %s", addr.String()) - return nil - } - - // retrieve local port - laddr := conn.netConn.LocalAddr() - localAddr, err := net.ResolveUDPAddr(laddr.Network(), laddr.String()) - if err != nil { - log.Error("resolve local UDPAddr error %v\n", err) - return nil - } - - log.Info("Dial up new UDP connection from %s to %s", localAddr.String(), addr.String()) - - conn.ConnData = &core.ConnectionData{ - Device: d.device, - CookieStore: &core.CookieStore{}, - RemoteTransactionMap: make(map[uint64]*core.RemoteTransaction), - LocalAddr: localAddr, - RemoteAddr: addr, - TimeoutMs: DefaultConnectionTimeoutMs, - SendQueue: make(chan *core.Packet, PacketQueueSizePerConnection), - RecvQueue: make(chan *core.Packet, PacketQueueSizePerConnection), - BlockSignal: make(chan struct{}), - SetTimeoutSignal: make(chan struct{}), - StopSignal: make(chan struct{}), - } - - // start connection receive routine - conn.ConnData.Add(1) - go d.recvPacketRoutine(conn) - - return conn -} - -func (d *UdpAC) sendMessageRoutine() { - defer d.wg.Done() - defer log.Info("sendMessageRoutine stopped") - - log.Info("sendMessageRoutine started") - - for { - select { - case <-d.signals.stop: - return - - case md, ok := <-d.sendMsgCh: - if !ok { - return - } - if md == nil || md.RemoteAddr == nil { - log.Warning("Invalid initiator session starter") - continue - } - - addrStr := md.RemoteAddr.String() - - d.remoteConnectionMutex.Lock() - conn, found := d.remoteConnectionMap[addrStr] - d.remoteConnectionMutex.Unlock() - - if found { - md.ConnData = conn.ConnData - } else { - conn = d.newConnection(md.RemoteAddr) - if conn == nil { - log.Error("Failed to dial to remote address: %s", addrStr) - continue - } - - d.remoteConnectionMutex.Lock() - d.remoteConnectionMap[addrStr] = conn - d.remoteConnectionMutex.Unlock() - - md.ConnData = conn.ConnData - - // launch connection routine - d.wg.Add(1) - go d.connectionRoutine(conn) - } - - d.device.SendMsgToPacket(md) - } - } -} - -func (d *UdpAC) SendPacket(pkt *core.Packet, conn *UdpConn) (n int, err error) { - defer func() { - atomic.AddUint64(&d.stats.totalSendBytes, uint64(n)) - atomic.StoreInt64(&conn.ConnData.LastLocalSendTime, time.Now().UnixNano()) - - if !pkt.KeepAfterSend { - d.device.ReleasePoolPacket(pkt) - } - }() - - pktType := core.HeaderTypeToString(pkt.HeaderType) - //log.Debug("Send [%s] packet (%s -> %s): %+v", pktType, conn.ConnData.LocalAddr.String(), conn.ConnData.RemoteAddr.String(), pkt.Content) - log.Info("Send [%s] packet (%s -> %s), %d bytes", pktType, conn.ConnData.LocalAddr.String(), conn.ConnData.RemoteAddr.String(), len(pkt.Content)) - log.Evaluate("Send [%s] packet (%s -> %s, %d bytes)", pktType, conn.ConnData.LocalAddr.String(), conn.ConnData.RemoteAddr.String(), len(pkt.Content)) - return conn.netConn.Write(pkt.Content) -} - -func (d *UdpAC) recvPacketRoutine(conn *UdpConn) { - addrStr := conn.ConnData.RemoteAddr.String() - - defer conn.ConnData.Done() - defer log.Debug("recvPacketRoutine for %s stopped", addrStr) - - log.Debug("recvPacketRoutine for %s started", addrStr) - - for { - select { - case <-conn.ConnData.StopSignal: - return - - default: - } - - // udp recv, blocking until packet arrives or netConn.Close() - pkt := d.device.AllocatePoolPacket() - n, err := conn.netConn.Read(pkt.Buf[:]) - if err != nil { - d.device.ReleasePoolPacket(pkt) - if n == 0 { - // udp connection closed, it is not an error - return - } - log.Error("Failed to receive from remote address %s (%v)", addrStr, err) - continue - } - - // add total recv bytes - atomic.AddUint64(&d.stats.totalRecvBytes, uint64(n)) - - // check minimal length - if n < core.HeaderSize { - d.device.ReleasePoolPacket(pkt) - log.Error("Received UDP packet from %s is too short, discard", addrStr) - continue - } - - pkt.Content = pkt.Buf[:n] - //log.Trace("receive udp packet (%s -> %s): %+v", conn.ConnData.RemoteAddr.String(), conn.ConnData.LocalAddr.String(), pkt.Content) - - typ, _, err := d.device.RecvPrecheck(pkt) - msgType := core.HeaderTypeToString(typ) - log.Info("Receive [%s] packet (%s -> %s), %d bytes", msgType, addrStr, conn.ConnData.LocalAddr.String(), n) - log.Evaluate("Receive [%s] packet (%s -> %s), %d bytes", msgType, addrStr, conn.ConnData.LocalAddr.String(), n) - if err != nil { - d.device.ReleasePoolPacket(pkt) - log.Warning("Receive [%s] packet (%s -> %s), precheck error: %v", msgType, addrStr, conn.ConnData.LocalAddr.String(), err) - log.Evaluate("Receive [%s] packet (%s -> %s) precheck error: %v", msgType, addrStr, conn.ConnData.LocalAddr.String(), err) - continue - } - - atomic.StoreInt64(&conn.ConnData.LastLocalRecvTime, time.Now().UnixNano()) - - conn.ConnData.ForwardInboundPacket(pkt) - } -} - -func (d *UdpAC) connectionRoutine(conn *UdpConn) { - addrStr := conn.ConnData.RemoteAddr.String() - - defer d.wg.Done() - defer log.Debug("Connection routine: %s stopped", addrStr) - - log.Debug("Connection routine: %s started", addrStr) - - // stop receiving packets and clean up - defer func() { - d.remoteConnectionMutex.Lock() - delete(d.remoteConnectionMap, addrStr) - d.remoteConnectionMutex.Unlock() - - conn.Close() - }() - - for { - select { - case <-d.signals.stop: - return - - case <-conn.ConnData.SetTimeoutSignal: - if conn.ConnData.TimeoutMs <= 0 { - log.Debug("Connection routine closed immediately") - return - } - - case <-time.After(time.Duration(conn.ConnData.TimeoutMs) * time.Millisecond): - // timeout, quit routine - log.Debug("Connection routine idle timeout") - return - - case pkt, ok := <-conn.ConnData.SendQueue: - if !ok { - return - } - if pkt == nil { - continue - } - d.SendPacket(pkt, conn) - - case pkt, ok := <-conn.ConnData.RecvQueue: - if !ok { - return - } - if pkt == nil { - continue - } - log.Debug("Received udp packet len [%d] from addr: %s\n", len(pkt.Content), addrStr) - - if pkt.HeaderType == core.NHP_KPL { - d.device.ReleasePoolPacket(pkt) - log.Info("Receive [NHP_KPL] message (%s -> %s)", addrStr, conn.ConnData.LocalAddr.String()) - continue - } - - if d.device.IsTransactionResponse(pkt.HeaderType) { - // forward to a specific transaction - transactionId := pkt.Counter() - transaction := d.device.FindLocalTransaction(transactionId) - if transaction != nil { - transaction.NextPacketCh <- pkt - continue - } - } - - pd := &core.PacketData{ - BasePacket: pkt, - ConnData: conn.ConnData, - InitTime: atomic.LoadInt64(&conn.ConnData.LastLocalRecvTime), - } - // generic receive - d.device.RecvPacketToMsg(pd) - - case <-conn.ConnData.BlockSignal: - log.Critical("blocking address %s", addrStr) - return - } - } -} - -func (d *UdpAC) recvMessageRoutine() { - defer d.wg.Done() - defer log.Info("recvMessageRoutine stopped") - - log.Info("recvMessageRoutine started") - - for { - select { - case <-d.signals.stop: - return - - case ppd, ok := <-d.recvMsgCh: - if !ok { - return - } - if ppd == nil { - continue - } - - switch ppd.HeaderType { - case core.NHP_AOP: - // deal with NHP_AOP message - go d.HandleACOperations(ppd) - } - } - } -} - -// keep interaction between ac and server in certain time interval to keep outwards ip path active -func (d *UdpAC) maintainServerConnectionRoutine() { - defer d.wg.Done() - defer log.Info("maintainServerConnectionRoutine stopped") - - log.Info("maintainServerConnectionRoutine started") - - // reset iptables before exiting - defer d.iptables.ResetAllInput() - - var discoveryRoutineWg sync.WaitGroup - defer discoveryRoutineWg.Wait() - - for { - // make a local copy of servers then iterate because next operations are time consuming (too long to use locked iteration) - d.serverPeerMutex.Lock() - var serverCount int32 = int32(len(d.serverPeerMap)) - discoveryQuitArr := make([]chan struct{}, 0, serverCount) - discoveryFailStatusArr := make([]*int32, 0, serverCount) - - for _, server := range d.serverPeerMap { - // launch discovery routine for each server - fail := new(int32) - discoveryFailStatusArr = append(discoveryFailStatusArr, fail) - quit := make(chan struct{}) - discoveryQuitArr = append(discoveryQuitArr, quit) - - discoveryRoutineWg.Add(1) - go d.serverDiscovery(server, &discoveryRoutineWg, fail, quit) - } - d.serverPeerMutex.Unlock() - - // check whether all server discovery failed. - // If so, open all blocked input - quitCheck := make(chan struct{}) - discoveryQuitArr = append(discoveryQuitArr, quitCheck) - discoveryRoutineWg.Add(1) - go func() { - defer discoveryRoutineWg.Done() - - for { - select { - case <-d.signals.stop: - return - case <-quitCheck: - return - case <-time.After(MinialServerDiscoveryInterval * time.Second): - var totalFail int32 - for _, status := range discoveryFailStatusArr { - totalFail += atomic.LoadInt32(status) - } - - if totalFail < int32(len(discoveryFailStatusArr)) { - d.iptables.ResetAllInput() - } else { - d.iptables.AcceptAllInput() - } - } - } - }() - - select { - case <-d.signals.stop: - return - case _, ok := <-d.signals.serverMapUpdated: - if !ok { - return - } - // stop all current discovery routines - for _, q := range discoveryQuitArr { - close(q) - } - // continue and restart with new server discovery cycle - } - } -} - -func (d *UdpAC) serverDiscovery(server *core.UdpPeer, discoveryRoutineWg *sync.WaitGroup, serverFailCount *int32, quit <-chan struct{}) { - defer discoveryRoutineWg.Done() - - acId := d.config.ACId - serverAddr := server.HostOrAddr() - server, sendAddr := d.ResolvePeer(server) - if sendAddr == nil { - log.Error("Cannot connect to nil server address") - return - } - - addrStr := sendAddr.String() - - defer log.Info("server discovery sub-routine at %s stopped", serverAddr) - log.Info("server discovery sub-routine at %s started", serverAddr) - - var failCount int - - for { - var lastSendTime int64 - var lastRecvTime int64 - var connected bool - - // find whether connection is already connected - d.remoteConnectionMutex.Lock() - conn, found := d.remoteConnectionMap[addrStr] - d.remoteConnectionMutex.Unlock() - - if found { - // connection based timing - lastSendTime = atomic.LoadInt64(&conn.ConnData.LastLocalSendTime) - lastRecvTime = atomic.LoadInt64(&conn.ConnData.LastLocalRecvTime) - connected = conn.connected.Load() - } else { - // peer based timing - conn = nil - lastSendTime = server.LastSendTime() - lastRecvTime = server.LastRecvTime() - } - - currTime := time.Now().UnixNano() - peerPbk := server.PublicKey() - - // when a server is not connected, try to connect in every ACLocalTransactionResponseTimeoutMs - // when a server is connected when ServerConnectionInterval is reached since last receive, try resend NHP_AOL for maintaining server connection - if !connected || (currTime-lastRecvTime) > int64(ReportToServerInterval*time.Second) { - // send NHP_AOL message to server - aolMsg := &common.ACOnlineMsg{ - ACId: acId, - AuthServiceId: d.config.AuthServiceId, - ResourceIds: d.config.ResourceIds, - } - aolBytes, _ := json.Marshal(aolMsg) - - aolMd := &core.MsgData{ - RemoteAddr: sendAddr.(*net.UDPAddr), - HeaderType: core.NHP_AOL, - TransactionId: d.device.NextCounterIndex(), - Compress: true, - PeerPk: peerPbk, - Message: aolBytes, - ResponseMsgCh: make(chan *core.PacketParserData), - } - - if !d.IsRunning() { - log.Error("ac(%s#%d)[ACOnline] MsgData channel closed or being closed, skip sending", acId, aolMd.TransactionId) - return - } - - d.sendMsgCh <- aolMd // create new connection - server.UpdateSend(currTime) - - // block until transaction completes or timeouts - ppd := <-aolMd.ResponseMsgCh - close(aolMd.ResponseMsgCh) - - var err error - func() { - defer func() { - if err != nil { - if conn != nil { - conn.connected.Store(false) - } - - failCount += 1 - if failCount%ServerDiscoveryRetryBeforeFail == 0 { - atomic.StoreInt32(serverFailCount, 1) - // remove failed connection - d.remoteConnectionMutex.Lock() - conn = d.remoteConnectionMap[addrStr] - if conn != nil { - log.Info("server discovery failed, close local connection: %s", conn.ConnData.LocalAddr.String()) - delete(d.remoteConnectionMap, addrStr) - } - d.remoteConnectionMutex.Unlock() - conn.Close() - } - log.Error("ac(%s#%d)[ACOnline] reporting to server %s failed", acId, aolMd.TransactionId, addrStr) - } - - }() - - if ppd.Error != nil { - log.Error("ac(%s#%d)[ACOnline] failed to receive response from server %s: %v", acId, aolMd.TransactionId, addrStr, ppd.Error) - err = ppd.Error - return - } - - if ppd.HeaderType != core.NHP_AAK { - log.Error("ac(%s#%d)[ACOnline] response from server %s has wrong type: %s", acId, aolMd.TransactionId, addrStr, core.HeaderTypeToString(ppd.HeaderType)) - err = common.ErrTransactionRepliedWithWrongType - return - } - - aakMsg := &common.ServerACAckMsg{} - err = json.Unmarshal(ppd.BodyMessage, aakMsg) - if err != nil { - log.Error("ac(%s#%d)[HandleACAck] failed to parse %s message: %v", acId, ppd.SenderTrxId, core.HeaderTypeToString(ppd.HeaderType), err) - return - } - - // server discovery succeeded - failCount = 0 - atomic.StoreInt32(serverFailCount, 0) - d.remoteConnectionMutex.Lock() - conn = d.remoteConnectionMap[addrStr] // conn must be available at this point - conn.connected.Store(true) - conn.externalAddr = aakMsg.ACAddr - d.remoteConnectionMutex.Unlock() - log.Info("ac(%s#%d)[ACOnline] succeed. ac external address is %s, replied by server %s", acId, aolMd.TransactionId, aakMsg.ACAddr, addrStr) - }() - - } else if connected { - if (currTime - lastSendTime) > int64(ServerKeepaliveInterval*time.Second) { - // send NHP_KPL to server if no send happens within ServerKeepaliveInterval - md := &core.MsgData{ - RemoteAddr: sendAddr.(*net.UDPAddr), - HeaderType: core.NHP_KPL, - //PeerPk: peerPbk, // pubkey not needed - TransactionId: d.device.NextCounterIndex(), - } - - d.sendMsgCh <- md // send NHP_KPL to server via existing connection - server.UpdateSend(currTime) - } - } - - select { - case <-d.signals.stop: - return - case <-quit: - return - case <-time.After(MinialServerDiscoveryInterval * time.Second): - // wait for ServerConnectionDiscoveryInterval - } - } -} - -func (d *UdpAC) AddServerPeer(server *core.UdpPeer) { - if server.DeviceType() == core.NHP_SERVER { - d.device.AddPeer(server) - - d.serverPeerMutex.Lock() - d.serverPeerMap[server.PublicKeyBase64()] = server - d.serverPeerMutex.Unlock() - - // renew server connection cycle - if len(d.signals.serverMapUpdated) == 0 { - d.signals.serverMapUpdated <- struct{}{} - } - } -} - -func (d *UdpAC) RemoveServerPeer(serverKey string) { - d.serverPeerMutex.Lock() - beforeSize := len(d.serverPeerMap) - delete(d.serverPeerMap, serverKey) - afterSize := len(d.serverPeerMap) - d.serverPeerMutex.Unlock() - - if beforeSize != afterSize { - // renew server connection cycle - if len(d.signals.serverMapUpdated) == 0 { - d.signals.serverMapUpdated <- struct{}{} - } - } -} - -func (d *UdpAC) GenerateAccessToken(au *AgentUser) string { - hashStr := au.Hash() - timeStr := strconv.FormatInt(time.Now().UnixNano(), 10) - au.hash.Write([]byte(timeStr)) - token := base64.StdEncoding.EncodeToString(au.hash.Sum(nil)) - - d.AgentUserTokenMutex.Lock() - defer d.AgentUserTokenMutex.Unlock() - - tokenMap, found := d.agentUserCodeMap[hashStr[0:1]] - if found { - (*tokenMap)[hashStr] = token - } else { - tokenMap = &map[string]string{hashStr: token} - d.agentUserCodeMap[hashStr[0:1]] = tokenMap - } - - // log.Debug("user %+v, hash: %s", au, hashStr) - // log.Debug("agentUserCodeMap: %+v", d.agentUserCodeMap) - // log.Debug("tokenMap: %+v", d.agentUserCodeMap[hashStr[0:1]]) - return token -} - -func (d *UdpAC) VerifyAccessToken(au *AgentUser, token string) bool { - hashStr := au.Hash() - - d.AgentUserTokenMutex.Lock() - defer d.AgentUserTokenMutex.Unlock() - - // log.Debug("verify access token: %s", token) - // log.Debug("user %+v, hash: %s", au, hashStr) - // log.Debug("agentUserCodeMap: %+v", d.agentUserCodeMap) - // log.Debug("tokenMap: %+v", d.agentUserCodeMap[hashStr[0:1]]) - - tokenMap, found := d.agentUserCodeMap[hashStr[0:1]] - if found { - foundToken, found := (*tokenMap)[hashStr] - if found { - return token == foundToken - } - } - - return false -} - -func (d *UdpAC) DeleteAccessToken(au *AgentUser) { - hashStr := au.Hash() - - d.AgentUserTokenMutex.Lock() - defer d.AgentUserTokenMutex.Unlock() - - tokenMap, found := d.agentUserCodeMap[hashStr[0:1]] - if found { - delete(*tokenMap, hashStr) - if len(*tokenMap) == 0 { - delete(d.agentUserCodeMap, hashStr[0:1]) - } - } -} - -// if the server uses hostname as destination, find the correct peer with the actual IP address -func (d *UdpAC) ResolvePeer(peer *core.UdpPeer) (*core.UdpPeer, net.Addr) { - addr := peer.SendAddr() - if addr == nil { - return peer, nil - } - - if len(peer.Hostname) == 0 { - // peer with fixed ip, no change - return peer, addr - } - - actualIp := peer.ResolvedIp() - if peer.Ip == actualIp { - // peer with the correct resolved address, no change - return peer, addr - } - - d.serverPeerMutex.Lock() - defer d.serverPeerMutex.Unlock() - for _, p := range d.serverPeerMap { - if p.Ip == actualIp { - p.CopyResolveStatus(peer) - return p, addr - } - } - - return peer, addr -} +package ac + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "net" + "path/filepath" + "sync" + "sync/atomic" + "time" + + "github.com/OpenNHP/opennhp/common" + "github.com/OpenNHP/opennhp/core" + "github.com/OpenNHP/opennhp/log" + "github.com/OpenNHP/opennhp/utils" + "github.com/OpenNHP/opennhp/version" +) + +var ( + ExeDirPath string +) + +type UdpAC struct { + config *Config + httpConfig *HttpConfig + iptables *utils.IPTables + ipset *utils.IPSet + + stats struct { + totalRecvBytes uint64 + totalSendBytes uint64 + } + + log *log.Logger + + remoteConnectionMutex sync.Mutex + remoteConnectionMap map[string]*UdpConn // indexed by remote UDP address + + serverPeerMutex sync.Mutex + serverPeerMap map[string]*core.UdpPeer // indexed by server's public key + + TokenStoreMutex sync.Mutex + tokenStore TokenStore + + device *core.Device + httpServer *HttpAC + wg sync.WaitGroup + running atomic.Bool + + signals struct { + stop chan struct{} + serverMapUpdated chan struct{} + } + + recvMsgCh <-chan *core.PacketParserData + sendMsgCh chan *core.MsgData +} + +type UdpConn struct { + ConnData *core.ConnectionData + netConn *net.UDPConn + connected atomic.Bool + externalAddr string +} + +func (c *UdpConn) Close() { + c.netConn.Close() + c.ConnData.Close() +} + +/* +dirPath: the path of app or shared library entry point +logLevel: 0: silent, 1: error, 2: info, 3: debug, 4: verbose +*/ +func (a *UdpAC) Start(dirPath string, logLevel int) (err error) { + common.ExeDirPath = dirPath + ExeDirPath = dirPath + // init logger + a.log = log.NewLogger("NHP-AC", logLevel, filepath.Join(ExeDirPath, "logs"), "ac") + log.SetGlobalLogger(a.log) + + log.Info("=========================================================") + log.Info("=== NHP-AC %s started ===", version.Version) + log.Info("=== REVISION %s ===", version.CommitId) + log.Info("=== RELEASE %s ===", version.BuildTime) + log.Info("=========================================================") + + // init config + err = a.loadBaseConfig() + if err != nil { + return err + } + + // load http config and turn on http server if needed + a.loadHttpConfig() + + a.iptables, err = utils.NewIPTables() + if err != nil { + log.Error("iptables command not found") + return + } + + a.ipset, err = utils.NewIPSet(false) + if err != nil { + log.Error("ipset command not found") + return + } + + prk, err := base64.StdEncoding.DecodeString(a.config.PrivateKeyBase64) + if err != nil { + log.Error("private key parse error %v\n", err) + return fmt.Errorf("private key parse error %v", err) + } + + a.device = core.NewDevice(core.NHP_AC, prk, nil) + if a.device == nil { + log.Critical("failed to create device %v\n", err) + return fmt.Errorf("failed to create device %v", err) + } + + a.remoteConnectionMap = make(map[string]*UdpConn) + a.serverPeerMap = make(map[string]*core.UdpPeer) + a.tokenStore = make(TokenStore) + + // load peers + a.loadPeers() + + a.signals.stop = make(chan struct{}) + a.signals.serverMapUpdated = make(chan struct{}, 1) + + a.recvMsgCh = a.device.DecryptedMsgQueue + a.sendMsgCh = make(chan *core.MsgData, core.SendQueueSize) + + // start device routines + a.device.Start() + + // start ac routines + a.wg.Add(4) + go a.tokenStoreRefreshRoutine() + go a.sendMessageRoutine() + go a.recvMessageRoutine() + go a.maintainServerConnectionRoutine() + + a.running.Store(true) + return nil +} + +func (ac *UdpAC) Stop() { + ac.running.Store(false) + close(ac.signals.stop) + + ac.device.Stop() + ac.StopConfigWatch() + ac.wg.Wait() + close(ac.sendMsgCh) + close(ac.signals.serverMapUpdated) + + log.Info("==========================") + log.Info("=== NHP-AC stopped ===") + log.Info("==========================") + ac.log.Close() +} + +func (a *UdpAC) IsRunning() bool { + return a.running.Load() +} + +func (a *UdpAC) newConnection(addr *net.UDPAddr) (conn *UdpConn) { + conn = &UdpConn{} + var err error + // unlike tcp, udp dial is fast (just socket bind), so no need to run in a thread + conn.netConn, err = net.DialUDP("udp", nil, addr) + if err != nil { + log.Error("could not connect to remote addr %s", addr.String()) + return nil + } + + // retrieve local port + laddr := conn.netConn.LocalAddr() + localAddr, err := net.ResolveUDPAddr(laddr.Network(), laddr.String()) + if err != nil { + log.Error("resolve local UDPAddr error %v\n", err) + return nil + } + + log.Info("Dial up new UDP connection from %s to %s", localAddr.String(), addr.String()) + + conn.ConnData = &core.ConnectionData{ + Device: a.device, + CookieStore: &core.CookieStore{}, + RemoteTransactionMap: make(map[uint64]*core.RemoteTransaction), + LocalAddr: localAddr, + RemoteAddr: addr, + TimeoutMs: DefaultConnectionTimeoutMs, + SendQueue: make(chan *core.Packet, PacketQueueSizePerConnection), + RecvQueue: make(chan *core.Packet, PacketQueueSizePerConnection), + BlockSignal: make(chan struct{}), + SetTimeoutSignal: make(chan struct{}), + StopSignal: make(chan struct{}), + } + + // start connection receive routine + conn.ConnData.Add(1) + go a.recvPacketRoutine(conn) + + return conn +} + +func (a *UdpAC) sendMessageRoutine() { + defer a.wg.Done() + defer log.Info("sendMessageRoutine stopped") + + log.Info("sendMessageRoutine started") + + for { + select { + case <-a.signals.stop: + return + + case md, ok := <-a.sendMsgCh: + if !ok { + return + } + if md == nil || md.RemoteAddr == nil { + log.Warning("Invalid initiator session starter") + continue + } + + addrStr := md.RemoteAddr.String() + + a.remoteConnectionMutex.Lock() + conn, found := a.remoteConnectionMap[addrStr] + a.remoteConnectionMutex.Unlock() + + if found { + md.ConnData = conn.ConnData + } else { + conn = a.newConnection(md.RemoteAddr) + if conn == nil { + log.Error("Failed to dial to remote address: %s", addrStr) + continue + } + + a.remoteConnectionMutex.Lock() + a.remoteConnectionMap[addrStr] = conn + a.remoteConnectionMutex.Unlock() + + md.ConnData = conn.ConnData + + // launch connection routine + a.wg.Add(1) + go a.connectionRoutine(conn) + } + + a.device.SendMsgToPacket(md) + } + } +} + +func (a *UdpAC) SendPacket(pkt *core.Packet, conn *UdpConn) (n int, err error) { + defer func() { + atomic.AddUint64(&a.stats.totalSendBytes, uint64(n)) + atomic.StoreInt64(&conn.ConnData.LastLocalSendTime, time.Now().UnixNano()) + + if !pkt.KeepAfterSend { + a.device.ReleasePoolPacket(pkt) + } + }() + + pktType := core.HeaderTypeToString(pkt.HeaderType) + //log.Debug("Send [%s] packet (%s -> %s): %+v", pktType, conn.ConnData.LocalAddr.String(), conn.ConnData.RemoteAddr.String(), pkt.Content) + log.Info("Send [%s] packet (%s -> %s), %d bytes", pktType, conn.ConnData.LocalAddr.String(), conn.ConnData.RemoteAddr.String(), len(pkt.Content)) + log.Evaluate("Send [%s] packet (%s -> %s, %d bytes)", pktType, conn.ConnData.LocalAddr.String(), conn.ConnData.RemoteAddr.String(), len(pkt.Content)) + return conn.netConn.Write(pkt.Content) +} + +func (a *UdpAC) recvPacketRoutine(conn *UdpConn) { + addrStr := conn.ConnData.RemoteAddr.String() + + defer conn.ConnData.Done() + defer log.Debug("recvPacketRoutine for %s stopped", addrStr) + + log.Debug("recvPacketRoutine for %s started", addrStr) + + for { + select { + case <-conn.ConnData.StopSignal: + return + + default: + } + + // udp recv, blocking until packet arrives or netConn.Close() + pkt := a.device.AllocatePoolPacket() + n, err := conn.netConn.Read(pkt.Buf[:]) + if err != nil { + a.device.ReleasePoolPacket(pkt) + if n == 0 { + // udp connection closed, it is not an error + return + } + log.Error("Failed to receive from remote address %s (%v)", addrStr, err) + continue + } + + // add total recv bytes + atomic.AddUint64(&a.stats.totalRecvBytes, uint64(n)) + + // check minimal length + if n < core.HeaderSize { + a.device.ReleasePoolPacket(pkt) + log.Error("Received UDP packet from %s is too short, discard", addrStr) + continue + } + + pkt.Content = pkt.Buf[:n] + //log.Trace("receive udp packet (%s -> %s): %+v", conn.ConnData.RemoteAddr.String(), conn.ConnData.LocalAddr.String(), pkt.Content) + + typ, _, err := a.device.RecvPrecheck(pkt) + msgType := core.HeaderTypeToString(typ) + log.Info("Receive [%s] packet (%s -> %s), %d bytes", msgType, addrStr, conn.ConnData.LocalAddr.String(), n) + log.Evaluate("Receive [%s] packet (%s -> %s), %d bytes", msgType, addrStr, conn.ConnData.LocalAddr.String(), n) + if err != nil { + a.device.ReleasePoolPacket(pkt) + log.Warning("Receive [%s] packet (%s -> %s), precheck error: %v", msgType, addrStr, conn.ConnData.LocalAddr.String(), err) + log.Evaluate("Receive [%s] packet (%s -> %s) precheck error: %v", msgType, addrStr, conn.ConnData.LocalAddr.String(), err) + continue + } + + atomic.StoreInt64(&conn.ConnData.LastLocalRecvTime, time.Now().UnixNano()) + + conn.ConnData.ForwardInboundPacket(pkt) + } +} + +func (a *UdpAC) connectionRoutine(conn *UdpConn) { + addrStr := conn.ConnData.RemoteAddr.String() + + defer a.wg.Done() + defer log.Debug("Connection routine: %s stopped", addrStr) + + log.Debug("Connection routine: %s started", addrStr) + + // stop receiving packets and clean up + defer func() { + a.remoteConnectionMutex.Lock() + delete(a.remoteConnectionMap, addrStr) + a.remoteConnectionMutex.Unlock() + + conn.Close() + }() + + for { + select { + case <-a.signals.stop: + return + + case <-conn.ConnData.SetTimeoutSignal: + if conn.ConnData.TimeoutMs <= 0 { + log.Debug("Connection routine closed immediately") + return + } + + case <-time.After(time.Duration(conn.ConnData.TimeoutMs) * time.Millisecond): + // timeout, quit routine + log.Debug("Connection routine idle timeout") + return + + case pkt, ok := <-conn.ConnData.SendQueue: + if !ok { + return + } + if pkt == nil { + continue + } + a.SendPacket(pkt, conn) + + case pkt, ok := <-conn.ConnData.RecvQueue: + if !ok { + return + } + if pkt == nil { + continue + } + log.Debug("Received udp packet len [%d] from addr: %s\n", len(pkt.Content), addrStr) + + if pkt.HeaderType == core.NHP_KPL { + a.device.ReleasePoolPacket(pkt) + log.Info("Receive [NHP_KPL] message (%s -> %s)", addrStr, conn.ConnData.LocalAddr.String()) + continue + } + + if a.device.IsTransactionResponse(pkt.HeaderType) { + // forward to a specific transaction + transactionId := pkt.Counter() + transaction := a.device.FindLocalTransaction(transactionId) + if transaction != nil { + transaction.NextPacketCh <- pkt + continue + } + } + + pd := &core.PacketData{ + BasePacket: pkt, + ConnData: conn.ConnData, + InitTime: atomic.LoadInt64(&conn.ConnData.LastLocalRecvTime), + } + // generic receive + a.device.RecvPacketToMsg(pd) + + case <-conn.ConnData.BlockSignal: + log.Critical("blocking address %s", addrStr) + return + } + } +} + +func (a *UdpAC) recvMessageRoutine() { + defer a.wg.Done() + defer log.Info("recvMessageRoutine stopped") + + log.Info("recvMessageRoutine started") + + for { + select { + case <-a.signals.stop: + return + + case ppd, ok := <-a.recvMsgCh: + if !ok { + return + } + if ppd == nil { + continue + } + + switch ppd.HeaderType { + case core.NHP_AOP: + // deal with NHP_AOP message + go a.HandleUdpACOperations(ppd) + } + } + } +} + +// keep interaction between ac and server in certain time interval to keep outwards ip path active +func (a *UdpAC) maintainServerConnectionRoutine() { + defer a.wg.Done() + defer log.Info("maintainServerConnectionRoutine stopped") + + log.Info("maintainServerConnectionRoutine started") + + // reset iptables before exiting + defer a.iptables.ResetAllInput() + + var discoveryRoutineWg sync.WaitGroup + defer discoveryRoutineWg.Wait() + + for { + // make a local copy of servers then iterate because next operations are time consuming (too long to use locked iteration) + a.serverPeerMutex.Lock() + var serverCount int32 = int32(len(a.serverPeerMap)) + discoveryQuitArr := make([]chan struct{}, 0, serverCount) + discoveryFailStatusArr := make([]*int32, 0, serverCount) + + for _, server := range a.serverPeerMap { + // launch discovery routine for each server + fail := new(int32) + discoveryFailStatusArr = append(discoveryFailStatusArr, fail) + quit := make(chan struct{}) + discoveryQuitArr = append(discoveryQuitArr, quit) + + discoveryRoutineWg.Add(1) + go a.serverDiscovery(server, &discoveryRoutineWg, fail, quit) + } + a.serverPeerMutex.Unlock() + + // check whether all server discovery failed. + // If so, open all blocked input + quitCheck := make(chan struct{}) + discoveryQuitArr = append(discoveryQuitArr, quitCheck) + discoveryRoutineWg.Add(1) + go func() { + defer discoveryRoutineWg.Done() + + for { + select { + case <-a.signals.stop: + return + case <-quitCheck: + return + case <-time.After(MinialServerDiscoveryInterval * time.Second): + var totalFail int32 + for _, status := range discoveryFailStatusArr { + totalFail += atomic.LoadInt32(status) + } + + if totalFail < int32(len(discoveryFailStatusArr)) { + a.iptables.ResetAllInput() + } else { + a.iptables.AcceptAllInput() + } + } + } + }() + + select { + case <-a.signals.stop: + return + case _, ok := <-a.signals.serverMapUpdated: + if !ok { + return + } + // stop all current discovery routines + for _, q := range discoveryQuitArr { + close(q) + } + // continue and restart with new server discovery cycle + } + } +} + +func (a *UdpAC) serverDiscovery(server *core.UdpPeer, discoveryRoutineWg *sync.WaitGroup, serverFailCount *int32, quit <-chan struct{}) { + defer discoveryRoutineWg.Done() + + acId := a.config.ACId + serverAddr := server.HostOrAddr() + server, sendAddr := a.ResolvePeer(server) + if sendAddr == nil { + log.Error("Cannot connect to nil server address") + return + } + + addrStr := sendAddr.String() + + defer log.Info("server discovery sub-routine at %s stopped", serverAddr) + log.Info("server discovery sub-routine at %s started", serverAddr) + + var failCount int + + for { + var lastSendTime int64 + var lastRecvTime int64 + var connected bool + + // find whether connection is already connected + a.remoteConnectionMutex.Lock() + conn, found := a.remoteConnectionMap[addrStr] + a.remoteConnectionMutex.Unlock() + + if found { + // connection based timing + lastSendTime = atomic.LoadInt64(&conn.ConnData.LastLocalSendTime) + lastRecvTime = atomic.LoadInt64(&conn.ConnData.LastLocalRecvTime) + connected = conn.connected.Load() + } else { + // peer based timing + conn = nil + lastSendTime = server.LastSendTime() + lastRecvTime = server.LastRecvTime() + } + + currTime := time.Now().UnixNano() + peerPbk := server.PublicKey() + + // when a server is not connected, try to connect in every ACLocalTransactionResponseTimeoutMs + // when a server is connected when ServerConnectionInterval is reached since last receive, try resend NHP_AOL for maintaining server connection + if !connected || (currTime-lastRecvTime) > int64(ReportToServerInterval*time.Second) { + // send NHP_AOL message to server + aolMsg := &common.ACOnlineMsg{ + ACId: acId, + AuthServiceId: a.config.AuthServiceId, + ResourceIds: a.config.ResourceIds, + } + aolBytes, _ := json.Marshal(aolMsg) + + aolMd := &core.MsgData{ + RemoteAddr: sendAddr.(*net.UDPAddr), + HeaderType: core.NHP_AOL, + TransactionId: a.device.NextCounterIndex(), + Compress: true, + PeerPk: peerPbk, + Message: aolBytes, + ResponseMsgCh: make(chan *core.PacketParserData), + } + + if !a.IsRunning() { + log.Error("ac(%s#%d)[ACOnline] MsgData channel closed or being closed, skip sending", acId, aolMd.TransactionId) + return + } + + a.sendMsgCh <- aolMd // create new connection + server.UpdateSend(currTime) + + // block until transaction completes or timeouts + ppd := <-aolMd.ResponseMsgCh + close(aolMd.ResponseMsgCh) + + var err error + func() { + defer func() { + if err != nil { + if conn != nil { + conn.connected.Store(false) + } + + failCount += 1 + if failCount%ServerDiscoveryRetryBeforeFail == 0 { + atomic.StoreInt32(serverFailCount, 1) + // remove failed connection + a.remoteConnectionMutex.Lock() + conn = a.remoteConnectionMap[addrStr] + if conn != nil { + log.Info("server discovery failed, close local connection: %s", conn.ConnData.LocalAddr.String()) + delete(a.remoteConnectionMap, addrStr) + } + a.remoteConnectionMutex.Unlock() + conn.Close() + } + log.Error("ac(%s#%d)[ACOnline] reporting to server %s failed", acId, aolMd.TransactionId, addrStr) + } + + }() + + if ppd.Error != nil { + log.Error("ac(%s#%d)[ACOnline] failed to receive response from server %s: %v", acId, aolMd.TransactionId, addrStr, ppd.Error) + err = ppd.Error + return + } + + if ppd.HeaderType != core.NHP_AAK { + log.Error("ac(%s#%d)[ACOnline] response from server %s has wrong type: %s", acId, aolMd.TransactionId, addrStr, core.HeaderTypeToString(ppd.HeaderType)) + err = common.ErrTransactionRepliedWithWrongType + return + } + + aakMsg := &common.ServerACAckMsg{} + err = json.Unmarshal(ppd.BodyMessage, aakMsg) + if err != nil { + log.Error("ac(%s#%d)[HandleACAck] failed to parse %s message: %v", acId, ppd.SenderTrxId, core.HeaderTypeToString(ppd.HeaderType), err) + return + } + + // server discovery succeeded + failCount = 0 + atomic.StoreInt32(serverFailCount, 0) + a.remoteConnectionMutex.Lock() + conn = a.remoteConnectionMap[addrStr] // conn must be available at this point + conn.connected.Store(true) + conn.externalAddr = aakMsg.ACAddr + a.remoteConnectionMutex.Unlock() + log.Info("ac(%s#%d)[ACOnline] succeed. ac external address is %s, replied by server %s", acId, aolMd.TransactionId, aakMsg.ACAddr, addrStr) + }() + + } else if connected { + if (currTime - lastSendTime) > int64(ServerKeepaliveInterval*time.Second) { + // send NHP_KPL to server if no send happens within ServerKeepaliveInterval + md := &core.MsgData{ + RemoteAddr: sendAddr.(*net.UDPAddr), + HeaderType: core.NHP_KPL, + //PeerPk: peerPbk, // pubkey not needed + TransactionId: a.device.NextCounterIndex(), + } + + a.sendMsgCh <- md // send NHP_KPL to server via existing connection + server.UpdateSend(currTime) + } + } + + select { + case <-a.signals.stop: + return + case <-quit: + return + case <-time.After(MinialServerDiscoveryInterval * time.Second): + // wait for ServerConnectionDiscoveryInterval + } + } +} + +func (a *UdpAC) AddServerPeer(server *core.UdpPeer) { + if server.DeviceType() == core.NHP_SERVER { + a.device.AddPeer(server) + + a.serverPeerMutex.Lock() + a.serverPeerMap[server.PublicKeyBase64()] = server + a.serverPeerMutex.Unlock() + + // renew server connection cycle + if len(a.signals.serverMapUpdated) == 0 { + a.signals.serverMapUpdated <- struct{}{} + } + } +} + +func (a *UdpAC) RemoveServerPeer(serverKey string) { + a.serverPeerMutex.Lock() + beforeSize := len(a.serverPeerMap) + delete(a.serverPeerMap, serverKey) + afterSize := len(a.serverPeerMap) + a.serverPeerMutex.Unlock() + + if beforeSize != afterSize { + // renew server connection cycle + if len(a.signals.serverMapUpdated) == 0 { + a.signals.serverMapUpdated <- struct{}{} + } + } +} + +// if the server uses hostname as destination, find the correct peer with the actual IP address +func (a *UdpAC) ResolvePeer(peer *core.UdpPeer) (*core.UdpPeer, net.Addr) { + addr := peer.SendAddr() + if addr == nil { + return peer, nil + } + + if len(peer.Hostname) == 0 { + // peer with fixed ip, no change + return peer, addr + } + + actualIp := peer.ResolvedIp() + if peer.Ip == actualIp { + // peer with the correct resolved address, no change + return peer, addr + } + + a.serverPeerMutex.Lock() + defer a.serverPeerMutex.Unlock() + for _, p := range a.serverPeerMap { + if p.Ip == actualIp { + p.CopyResolveStatus(peer) + return p, addr + } + } + + return peer, addr +} diff --git a/common/nhpmsg.go b/common/nhpmsg.go index cd800dce..c7a2672a 100644 --- a/common/nhpmsg.go +++ b/common/nhpmsg.go @@ -125,6 +125,7 @@ type ACOpsResultMsg struct { ErrCode string `json:"errCode"` ErrMsg string `json:"errMsg,omitempty"` OpenTime uint32 `json:"opnTime"` + ACToken string `json:"token"` PreAccessAction *PreAccessInfo `json:"preAct"` } @@ -134,6 +135,11 @@ type ACOnlineMsg struct { ACId string `json:"acId,omitempty"` } +type ACRefreshMsg struct { + NhpToken string `json:"nhpToken"` + SourceAddr *NetAddress `json:"srcAddr"` +} + type ServerACAckMsg struct { ErrCode string `json:"errCode"` ErrMsg string `json:"errMsg,omitempty"` diff --git a/common/types.go b/common/types.go index ddb9cab9..144cd4b7 100644 --- a/common/types.go +++ b/common/types.go @@ -76,3 +76,8 @@ type HttpKnockRequest struct { SrcIp string `json:"-"` SrcPort int `json:"-"` } + +type HttpRefreshRequest struct { + Token string `json:"token"` + SrcIp string `json:"srcIp"` +} diff --git a/server/config.go b/server/config.go index a53ed74d..1a55e025 100644 --- a/server/config.go +++ b/server/config.go @@ -1,370 +1,370 @@ -package server - -import ( - "fmt" - "io" - "os" - "path/filepath" - - "github.com/OpenNHP/opennhp/common" - "github.com/OpenNHP/opennhp/core" - "github.com/OpenNHP/opennhp/log" - "github.com/OpenNHP/opennhp/plugins" - "github.com/OpenNHP/opennhp/utils" - - toml "github.com/pelletier/go-toml/v2" -) - -var ( - baseConfigWatch io.Closer - httpConfigWatch io.Closer - acConfigWatch io.Closer - agentConfigWatch io.Closer - resConfigWatch io.Closer - srcipConfigWatch io.Closer - - errLoadConfig = fmt.Errorf("config load error") -) - -type Config struct { - PrivateKeyBase64 string `json:"privateKey"` - ListenIp string `json:"listenIp"` - ListenPort int `json:"listenPort"` - LogLevel int `json:"logLevel"` - Hostname string `json:"hostname"` - DisableAgentValidation bool `json:"disableAgentValidation"` -} - -type HttpConfig struct { - EnableHttp bool - EnableTLS bool - HttpListenIp string - TLSCertFile string - TLSKeyFile string -} - -type Peers struct { - ACs []*core.UdpPeer - Agents []*core.UdpPeer -} - -func (s *UdpServer) loadBaseConfig() error { - // config.toml - fileName := filepath.Join(ExeDirPath, "etc", "config.toml") - if err := s.updateBaseConfig(fileName); err != nil { - // report base config error - return err - } - - baseConfigWatch = utils.WatchFile(fileName, func() { - log.Info("base config: %s has been updated", fileName) - s.updateBaseConfig(fileName) - }) - return nil -} - -func (s *UdpServer) loadHttpConfig() error { - // http.toml - fileName := filepath.Join(ExeDirPath, "etc", "http.toml") - if err := s.updateHttpConfig(fileName); err != nil { - // ignore error - _ = err - } - - httpConfigWatch = utils.WatchFile(fileName, func() { - log.Info("http config: %s has been updated", fileName) - s.updateHttpConfig(fileName) - }) - return nil -} - -func (s *UdpServer) loadPeers() error { - // ac.toml - fileNameAC := filepath.Join(ExeDirPath, "etc", "ac.toml") - if err := s.updateACPeers(fileNameAC); err != nil { - // ignore error - _ = err - } - - acConfigWatch = utils.WatchFile(fileNameAC, func() { - log.Info("ac peer config: %s has been updated", fileNameAC) - s.updateACPeers(fileNameAC) - }) - - // agent.toml - fileNameAgent := filepath.Join(ExeDirPath, "etc", "agent.toml") - if err := s.updateAgentPeers(fileNameAgent); err != nil { - // ignore error - _ = err - } - - agentConfigWatch = utils.WatchFile(fileNameAgent, func() { - log.Info("agent peer config: %s has been updated", fileNameAgent) - s.updateAgentPeers(fileNameAgent) - }) - return nil -} - -func (s *UdpServer) loadResources() error { - // resource.toml - fileName := filepath.Join(ExeDirPath, "etc", "resource.toml") - if err := s.updateResources(fileName); err != nil { - // ignore error - _ = err - } - - resConfigWatch = utils.WatchFile(fileName, func() { - log.Info("resource config: %s has been updated", fileName) - s.updateResources(fileName) - }) - return nil -} - -func (s *UdpServer) loadSourceIps() error { - // srcip.toml - fileName := filepath.Join(ExeDirPath, "etc", "srcip.toml") - if err := s.updateSourceIps(fileName); err != nil { - // ignore error - _ = err - } - - srcipConfigWatch = utils.WatchFile(fileName, func() { - log.Info("src ip config: %s has been updated", fileName) - s.updateSourceIps(fileName) - }) - return nil -} - -func (s *UdpServer) updateBaseConfig(file string) (err error) { - utils.CatchPanicThenRun(func() { - err = errLoadConfig - }) - - content, err := os.ReadFile(file) - if err != nil { - log.Error("failed to read base config: %v", err) - } - - var conf Config - if err := toml.Unmarshal(content, &conf); err != nil { - log.Error("failed to unmarshal base config: %v", err) - } - - if s.config == nil { - s.config = &conf - s.log.SetLogLevel(conf.LogLevel) - return err - } - - // update - if s.config.LogLevel != conf.LogLevel { - log.Info("set base log level to %d", conf.LogLevel) - s.log.SetLogLevel(conf.LogLevel) - s.config.LogLevel = conf.LogLevel - } - - if s.config.DisableAgentValidation != conf.DisableAgentValidation { - if s.device != nil { - s.device.SetOption(core.DeviceOptions{ - DisableAgentPeerValidation: conf.DisableAgentValidation, - }) - } - s.config.DisableAgentValidation = conf.DisableAgentValidation - } - - return err -} - -func (s *UdpServer) updateHttpConfig(file string) (err error) { - utils.CatchPanicThenRun(func() { - err = errLoadConfig - }) - - content, err := os.ReadFile(file) - if err != nil { - log.Error("failed to read base config: %v", err) - } - - var httpConf HttpConfig - if err := toml.Unmarshal(content, &httpConf); err != nil { - log.Error("failed to unmarshal base config: %v", err) - } - - // update - if httpConf.EnableHttp { - // start http server - if s.httpServer == nil || !s.httpServer.IsRunning() { - if s.httpServer != nil { - // stop old http server - go s.httpServer.Stop() - } - hs := &HttpServer{} - s.httpServer = hs - err = hs.Start(s, &httpConf) - if err != nil { - return err - } - } - } else { - // stop http server - if s.httpServer != nil && s.httpServer.IsRunning() { - go s.httpServer.Stop() - s.httpServer = nil - } - } - - s.httpConfig = &httpConf - return err -} - -func (s *UdpServer) updateACPeers(file string) (err error) { - utils.CatchPanicThenRun(func() { - err = errLoadConfig - }) - - content, err := os.ReadFile(file) - if err != nil { - log.Error("failed to read ac peer config: %v", err) - } - - // update - var peers Peers - acPeerMap := make(map[string]*core.UdpPeer) - if err := toml.Unmarshal(content, &peers); err != nil { - log.Error("failed to unmarshal ac peer config: %v", err) - } - for _, p := range peers.ACs { - p.Type = core.NHP_AC - s.device.AddPeer(p) - acPeerMap[p.PublicKeyBase64()] = p - } - - // remove old peers from device - s.acPeerMapMutex.Lock() - defer s.acPeerMapMutex.Unlock() - for pubKey := range s.acPeerMap { - if _, found := acPeerMap[pubKey]; !found { - s.device.RemovePeer(pubKey) - } - } - s.acPeerMap = acPeerMap - - return err -} - -func (s *UdpServer) updateAgentPeers(file string) (err error) { - utils.CatchPanicThenRun(func() { - err = errLoadConfig - }) - - content, err := os.ReadFile(file) - if err != nil { - log.Error("failed to read agent peer config: %v", err) - } - - var peers Peers - agentPeerMap := make(map[string]*core.UdpPeer) - if err := toml.Unmarshal(content, &peers); err != nil { - log.Error("failed to unmarshal agent peer config: %v", err) - } - for _, p := range peers.Agents { - p.Type = core.NHP_AGENT - s.device.AddPeer(p) - agentPeerMap[p.PublicKeyBase64()] = p - } - - // remove old peers from device - s.agentPeerMapMutex.Lock() - defer s.agentPeerMapMutex.Unlock() - for pubKey := range s.agentPeerMap { - if _, found := agentPeerMap[pubKey]; !found { - s.device.RemovePeer(pubKey) - } - } - s.agentPeerMap = agentPeerMap - - return err -} - -func (s *UdpServer) updateResources(file string) (err error) { - utils.CatchPanicThenRun(func() { - err = errLoadConfig - }) - - content, err := os.ReadFile(file) - if err != nil { - log.Error("failed to read resource config: %v", err) - } - - // update - aspMap := make(common.AuthSvcProviderMap) - if err := toml.Unmarshal(content, &aspMap); err != nil { - log.Error("failed to unmarshal resource config: %v", err) - } - - for aspId, aspData := range aspMap { - aspData.AuthSvcId = aspId - if len(aspData.PluginPath) > 0 { - h := plugins.ReadPluginHandler(aspData.PluginPath) - if h != nil { - s.LoadPlugin(aspId, h) - } - } - - for resId, res := range aspData.ResourceGroups { - // Note: res is a pointer, so we can update its value - res.AuthServiceId = aspId - res.ResourceId = resId - } - } - - s.authServiceMapMutex.Lock() - defer s.authServiceMapMutex.Unlock() - s.authServiceMap = aspMap - - return err -} - -func (s *UdpServer) updateSourceIps(file string) (err error) { - utils.CatchPanicThenRun(func() { - err = errLoadConfig - }) - - content, err := os.ReadFile(file) - if err != nil { - log.Error("failed to read src ip config: %v", err) - } - - // update - srcIpMap := make(map[string][]*common.NetAddress) - if err := toml.Unmarshal(content, &srcIpMap); err != nil { - log.Error("failed to unmarshal src ip config: %v", err) - } - - s.srcIpAssociatedAddrMapMutex.Lock() - defer s.srcIpAssociatedAddrMapMutex.Unlock() - s.srcIpAssociatedAddrMap = srcIpMap - - return err -} - -func (s *UdpServer) StopConfigWatch() { - if baseConfigWatch != nil { - baseConfigWatch.Close() - } - if httpConfigWatch != nil { - httpConfigWatch.Close() - } - if acConfigWatch != nil { - acConfigWatch.Close() - } - if agentConfigWatch != nil { - agentConfigWatch.Close() - } - if resConfigWatch != nil { - resConfigWatch.Close() - } - if srcipConfigWatch != nil { - srcipConfigWatch.Close() - } -} +package server + +import ( + "fmt" + "io" + "os" + "path/filepath" + + "github.com/OpenNHP/opennhp/common" + "github.com/OpenNHP/opennhp/core" + "github.com/OpenNHP/opennhp/log" + "github.com/OpenNHP/opennhp/plugins" + "github.com/OpenNHP/opennhp/utils" + + toml "github.com/pelletier/go-toml/v2" +) + +var ( + baseConfigWatch io.Closer + httpConfigWatch io.Closer + acConfigWatch io.Closer + agentConfigWatch io.Closer + resConfigWatch io.Closer + srcipConfigWatch io.Closer + + errLoadConfig = fmt.Errorf("config load error") +) + +type Config struct { + PrivateKeyBase64 string `json:"privateKey"` + ListenIp string `json:"listenIp"` + ListenPort int `json:"listenPort"` + LogLevel int `json:"logLevel"` + Hostname string `json:"hostname"` + DisableAgentValidation bool `json:"disableAgentValidation"` +} + +type HttpConfig struct { + EnableHttp bool + EnableTLS bool + HttpListenIp string + TLSCertFile string + TLSKeyFile string +} + +type Peers struct { + ACs []*core.UdpPeer + Agents []*core.UdpPeer +} + +func (s *UdpServer) loadBaseConfig() error { + // config.toml + fileName := filepath.Join(ExeDirPath, "etc", "config.toml") + if err := s.updateBaseConfig(fileName); err != nil { + // report base config error + return err + } + + baseConfigWatch = utils.WatchFile(fileName, func() { + log.Info("base config: %s has been updated", fileName) + s.updateBaseConfig(fileName) + }) + return nil +} + +func (s *UdpServer) loadHttpConfig() error { + // http.toml + fileName := filepath.Join(ExeDirPath, "etc", "http.toml") + if err := s.updateHttpConfig(fileName); err != nil { + // ignore error + _ = err + } + + httpConfigWatch = utils.WatchFile(fileName, func() { + log.Info("http config: %s has been updated", fileName) + s.updateHttpConfig(fileName) + }) + return nil +} + +func (s *UdpServer) loadPeers() error { + // ac.toml + fileNameAC := filepath.Join(ExeDirPath, "etc", "ac.toml") + if err := s.updateACPeers(fileNameAC); err != nil { + // ignore error + _ = err + } + + acConfigWatch = utils.WatchFile(fileNameAC, func() { + log.Info("ac peer config: %s has been updated", fileNameAC) + s.updateACPeers(fileNameAC) + }) + + // agent.toml + fileNameAgent := filepath.Join(ExeDirPath, "etc", "agent.toml") + if err := s.updateAgentPeers(fileNameAgent); err != nil { + // ignore error + _ = err + } + + agentConfigWatch = utils.WatchFile(fileNameAgent, func() { + log.Info("agent peer config: %s has been updated", fileNameAgent) + s.updateAgentPeers(fileNameAgent) + }) + return nil +} + +func (s *UdpServer) loadResources() error { + // resource.toml + fileName := filepath.Join(ExeDirPath, "etc", "resource.toml") + if err := s.updateResources(fileName); err != nil { + // ignore error + _ = err + } + + resConfigWatch = utils.WatchFile(fileName, func() { + log.Info("resource config: %s has been updated", fileName) + s.updateResources(fileName) + }) + return nil +} + +func (s *UdpServer) loadSourceIps() error { + // srcip.toml + fileName := filepath.Join(ExeDirPath, "etc", "srcip.toml") + if err := s.updateSourceIps(fileName); err != nil { + // ignore error + _ = err + } + + srcipConfigWatch = utils.WatchFile(fileName, func() { + log.Info("src ip config: %s has been updated", fileName) + s.updateSourceIps(fileName) + }) + return nil +} + +func (s *UdpServer) updateBaseConfig(file string) (err error) { + utils.CatchPanicThenRun(func() { + err = errLoadConfig + }) + + content, err := os.ReadFile(file) + if err != nil { + log.Error("failed to read base config: %v", err) + } + + var conf Config + if err := toml.Unmarshal(content, &conf); err != nil { + log.Error("failed to unmarshal base config: %v", err) + } + + if s.config == nil { + s.config = &conf + s.log.SetLogLevel(conf.LogLevel) + return err + } + + // update + if s.config.LogLevel != conf.LogLevel { + log.Info("set base log level to %d", conf.LogLevel) + s.log.SetLogLevel(conf.LogLevel) + s.config.LogLevel = conf.LogLevel + } + + if s.config.DisableAgentValidation != conf.DisableAgentValidation { + if s.device != nil { + s.device.SetOption(core.DeviceOptions{ + DisableAgentPeerValidation: conf.DisableAgentValidation, + }) + } + s.config.DisableAgentValidation = conf.DisableAgentValidation + } + + return err +} + +func (s *UdpServer) updateHttpConfig(file string) (err error) { + utils.CatchPanicThenRun(func() { + err = errLoadConfig + }) + + content, err := os.ReadFile(file) + if err != nil { + log.Error("failed to read http config: %v", err) + } + + var httpConf HttpConfig + if err := toml.Unmarshal(content, &httpConf); err != nil { + log.Error("failed to unmarshal http config: %v", err) + } + + // update + if httpConf.EnableHttp { + // start http server + if s.httpServer == nil || !s.httpServer.IsRunning() { + if s.httpServer != nil { + // stop old http server + go s.httpServer.Stop() + } + hs := &HttpServer{} + s.httpServer = hs + err = hs.Start(s, &httpConf) + if err != nil { + return err + } + } + } else { + // stop http server + if s.httpServer != nil && s.httpServer.IsRunning() { + go s.httpServer.Stop() + s.httpServer = nil + } + } + + s.httpConfig = &httpConf + return err +} + +func (s *UdpServer) updateACPeers(file string) (err error) { + utils.CatchPanicThenRun(func() { + err = errLoadConfig + }) + + content, err := os.ReadFile(file) + if err != nil { + log.Error("failed to read ac peer config: %v", err) + } + + // update + var peers Peers + acPeerMap := make(map[string]*core.UdpPeer) + if err := toml.Unmarshal(content, &peers); err != nil { + log.Error("failed to unmarshal ac peer config: %v", err) + } + for _, p := range peers.ACs { + p.Type = core.NHP_AC + s.device.AddPeer(p) + acPeerMap[p.PublicKeyBase64()] = p + } + + // remove old peers from device + s.acPeerMapMutex.Lock() + defer s.acPeerMapMutex.Unlock() + for pubKey := range s.acPeerMap { + if _, found := acPeerMap[pubKey]; !found { + s.device.RemovePeer(pubKey) + } + } + s.acPeerMap = acPeerMap + + return err +} + +func (s *UdpServer) updateAgentPeers(file string) (err error) { + utils.CatchPanicThenRun(func() { + err = errLoadConfig + }) + + content, err := os.ReadFile(file) + if err != nil { + log.Error("failed to read agent peer config: %v", err) + } + + var peers Peers + agentPeerMap := make(map[string]*core.UdpPeer) + if err := toml.Unmarshal(content, &peers); err != nil { + log.Error("failed to unmarshal agent peer config: %v", err) + } + for _, p := range peers.Agents { + p.Type = core.NHP_AGENT + s.device.AddPeer(p) + agentPeerMap[p.PublicKeyBase64()] = p + } + + // remove old peers from device + s.agentPeerMapMutex.Lock() + defer s.agentPeerMapMutex.Unlock() + for pubKey := range s.agentPeerMap { + if _, found := agentPeerMap[pubKey]; !found { + s.device.RemovePeer(pubKey) + } + } + s.agentPeerMap = agentPeerMap + + return err +} + +func (s *UdpServer) updateResources(file string) (err error) { + utils.CatchPanicThenRun(func() { + err = errLoadConfig + }) + + content, err := os.ReadFile(file) + if err != nil { + log.Error("failed to read resource config: %v", err) + } + + // update + aspMap := make(common.AuthSvcProviderMap) + if err := toml.Unmarshal(content, &aspMap); err != nil { + log.Error("failed to unmarshal resource config: %v", err) + } + + for aspId, aspData := range aspMap { + aspData.AuthSvcId = aspId + if len(aspData.PluginPath) > 0 { + h := plugins.ReadPluginHandler(aspData.PluginPath) + if h != nil { + s.LoadPlugin(aspId, h) + } + } + + for resId, res := range aspData.ResourceGroups { + // Note: res is a pointer, so we can update its value + res.AuthServiceId = aspId + res.ResourceId = resId + } + } + + s.authServiceMapMutex.Lock() + defer s.authServiceMapMutex.Unlock() + s.authServiceMap = aspMap + + return err +} + +func (s *UdpServer) updateSourceIps(file string) (err error) { + utils.CatchPanicThenRun(func() { + err = errLoadConfig + }) + + content, err := os.ReadFile(file) + if err != nil { + log.Error("failed to read src ip config: %v", err) + } + + // update + srcIpMap := make(map[string][]*common.NetAddress) + if err := toml.Unmarshal(content, &srcIpMap); err != nil { + log.Error("failed to unmarshal src ip config: %v", err) + } + + s.srcIpAssociatedAddrMapMutex.Lock() + defer s.srcIpAssociatedAddrMapMutex.Unlock() + s.srcIpAssociatedAddrMap = srcIpMap + + return err +} + +func (s *UdpServer) StopConfigWatch() { + if baseConfigWatch != nil { + baseConfigWatch.Close() + } + if httpConfigWatch != nil { + httpConfigWatch.Close() + } + if acConfigWatch != nil { + acConfigWatch.Close() + } + if agentConfigWatch != nil { + agentConfigWatch.Close() + } + if resConfigWatch != nil { + resConfigWatch.Close() + } + if srcipConfigWatch != nil { + srcipConfigWatch.Close() + } +}