Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

添加热更新的blacklist_file和流量统计 #16

Open
wants to merge 10 commits into
base: v2
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"1.2.3.4":true,
"114.114.114.114":true
},
"blacklist_file": "/yourblacklistfliepath",
"targets": [
目标配置
]
Expand All @@ -46,7 +47,8 @@
3. enable_regexp为是否开启正则表达式模式,后面有解释
4. first_packet_timeout为等待客户端第一个数据包的超时时间(**毫秒**),仅开启正则表达式模式后有效,后面有解释
5. blacklist为黑名单IP,在黑名单里面的IP且为true的时候则直接断开链接。如不需要使用黑名单可留null
5. targets为目标配置数组,看下面
6. blacklist_file文件为单行单IP黑名单,多行多IP同时黑名单
7. targets为目标配置数组,看下面

#### 目标配置
目标配置有两种模式:**普通模式**和**正则模式**。
Expand Down Expand Up @@ -74,6 +76,7 @@
"1.2.3.4":true,
"114.114.114.114":true
},
"blacklist_file": "/yourblacklistfliepath",
"targets": [
{
"address": "127.0.0.1:80"
Expand All @@ -89,6 +92,7 @@
"1.2.3.4":true,
"114.114.114.114":true
},
"blacklist_file": "/yourblacklistfliepath",
"targets": [
{
"regexp": "^(GET|POST|HEAD|DELETE|PUT|CONNECT|OPTIONS|TRACE)",
Expand Down
72 changes: 71 additions & 1 deletion config.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
package main

import (
"bufio"
"encoding/json"
"flag"
"fmt"
"github.com/sirupsen/logrus"
"io/ioutil"
"os"
"regexp"
"strings"
"sync"
"time"
)

type configStructure struct {
Expand All @@ -24,10 +29,12 @@ type ruleStructure struct {
Address string `json:"address"`
} `json:"targets"`
FirstPacketTimeout uint64 `json:"first_packet_timeout"`
Blacklist map[string]bool `json:"blacklist"`
BlacklistFile string `json:"blacklist_file"`
blacklistMap map[string]bool `json:"-"`
}

var config *configStructure
var blacklistMutex = &sync.Mutex{}

func init() {
cfgPath := flag.String("config", "config.json", "config.json file path")
Expand Down Expand Up @@ -85,5 +92,68 @@ func (c *ruleStructure) verify() error {
v.regexp = r
}
}
if c.BlacklistFile != "" {
err := loadBlacklist(c.BlacklistFile, &c.blacklistMap)
if err != nil {
return fmt.Errorf("failed to load blacklist: %s", err.Error())
}
go watchBlacklist(c.BlacklistFile, &c.blacklistMap)
}
return nil
}

func loadBlacklist(path string, blacklist *map[string]bool) error {
file, err := os.Open(path)
if err != nil {
return err
}
defer file.Close()

newBlacklist := make(map[string]bool)
scanner := bufio.NewScanner(file)
for scanner.Scan() {
ip := strings.TrimSpace(scanner.Text())
newBlacklist[ip] = true
}

if err := scanner.Err(); err != nil {
return err
}

blacklistMutex.Lock()
oldBlacklist := *blacklist
*blacklist = newBlacklist
blacklistMutex.Unlock()

// 打印出被移除的IP地址
for ip := range oldBlacklist {
if !newBlacklist[ip] {
logrus.Infof("At %s, IP %s move out the Blacklist", time.Now().Format(time.RFC3339), ip)
}
}

// 打印出新添加的IP地址
for ip := range newBlacklist {
if !oldBlacklist[ip] {
logrus.Infof("At %s, IP %s add to the Blacklist", time.Now().Format(time.RFC3339), ip)
}
}

return nil
}


func watchBlacklist(path string, blacklist *map[string]bool) {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()

for {
select {
case <-ticker.C:
err := loadBlacklist(path, blacklist)
if err != nil {
logrus.Errorf("failed to reload blacklist: %s", err.Error())
}
}
}
}
3 changes: 2 additions & 1 deletion config.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"enable_regexp": false,
"first_packet_timeout": 5000,
"blacklist": null,
"blacklist_file": "/yourblacklistfliepath",
"targets": [
{
"regexp": "^(GET|POST|HEAD|DELETE|PUT|CONNECT|OPTIONS|TRACE)",
Expand All @@ -19,4 +20,4 @@
]
}
]
}
}
118 changes: 96 additions & 22 deletions core.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ import (
"time"
)

// 定义一个全局的流量统计map
var trafficMap = make(map[string]int64)
var trafficMutex = &sync.Mutex{}

func listen(rule *ruleStructure, wg *sync.WaitGroup) {
defer wg.Done()
//监听
Expand All @@ -28,10 +32,13 @@ func listen(rule *ruleStructure, wg *sync.WaitGroup) {
continue
}
//判断黑名单
if len(rule.Blacklist) != 0 {
blacklistMutex.Lock()
blacklist := rule.blacklistMap
blacklistMutex.Unlock()
if len(blacklist) != 0 {
clientIP := conn.RemoteAddr().String()
clientIP = clientIP[0:strings.LastIndex(clientIP, ":")]
if rule.Blacklist[clientIP] {
if blacklist[clientIP] {
logrus.Infof("[%s] disconnected ip in blacklist: %s", rule.Name, clientIP)
conn.Close()
continue
Expand All @@ -50,76 +57,118 @@ func handleNormal(conn net.Conn, rule *ruleStructure) {
defer conn.Close()

var target net.Conn
//正常模式下挨个连接直到成功连接
var targetAddress string
for _, v := range rule.Targets {
c, err := net.Dial("tcp", v.Address)
if err != nil {
logrus.Errorf("[%s] try to handle connection (%s) failed because target (%s) connected failed, try next target.",
logrus.Errorf("[%s] try to handle connection %s failed because target %s connected failed, try next target.",
rule.Name, conn.RemoteAddr(), v.Address)
continue
}
target = c
targetAddress = v.Address
break
}
if target == nil {
logrus.Errorf("[%s] unable to handle connection (%s) because all targets connected failed",
logrus.Errorf("[%s] unable to handle connection %s because all targets connected failed",
rule.Name, conn.RemoteAddr())
return
}
logrus.Debugf("[%s] handle connection (%s) to target (%s)", rule.Name, conn.RemoteAddr(), target.RemoteAddr())
logrus.Debugf("[%s] handle connection %s to target %s", rule.Name, conn.RemoteAddr(), target.RemoteAddr())

defer target.Close()

//io桥
go io.Copy(conn, target)
io.Copy(target, conn)
var wg sync.WaitGroup
wg.Add(2)

var traffic1, traffic2 int64
ip, _, _ := net.SplitHostPort(conn.RemoteAddr().String())
go func() {
defer wg.Done()
traffic1 = copyWithTrafficCount(conn, target)
trafficMutex.Lock()
trafficMap[ip] += traffic1
trafficMutex.Unlock()
}()
go func() {
defer wg.Done()
traffic2 = copyWithTrafficCount(target, conn)
trafficMutex.Lock()
trafficMap[ip] += traffic2
trafficMutex.Unlock()
}()

wg.Wait()

trafficMutex.Lock()
logrus.Infof("[%s] %s to target %s: This connection traffic: %.2f MB, Total traffic: %.2f MB", rule.Name, conn.RemoteAddr().String(), targetAddress, float64(traffic1 + traffic2) / (1024 * 1024), float64(trafficMap[ip]) / (1024 * 1024))
trafficMutex.Unlock()
}

func handleRegexp(conn net.Conn, rule *ruleStructure) {
defer conn.Close()

//正则模式下需要客户端的第一个数据包判断特征,所以需要设置一个超时
conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(rule.FirstPacketTimeout)))
//获取第一个数据包
firstPacket, err := waitFirstPacket(conn)
if err != nil {
logrus.Errorf("[%s] unable to handle connection (%s) because failed to get first packet : %s",
logrus.Errorf("[%s] unable to handle connection %s because failed to get first packet : %s",
rule.Name, conn.RemoteAddr(), err.Error())
return
}

var target net.Conn
//挨个匹配正则
var targetAddress string
for _, v := range rule.Targets {
if !v.regexp.Match(firstPacket) {
continue
}
c, err := net.Dial("tcp", v.Address)
if err != nil {
logrus.Errorf("[%s] try to handle connection (%s) failed because target (%s) connected failed, try next match target.",
logrus.Errorf("[%s] try to handle connection %s failed because target %s connected failed, try next match target.",
rule.Name, conn.RemoteAddr(), v.Address)
continue
}
target = c
targetAddress = v.Address
break
}
if target == nil {
logrus.Errorf("[%s] unable to handle connection (%s) because no match target",
logrus.Errorf("[%s] unable to handle connection %s because no match target",
rule.Name, conn.RemoteAddr())
return
}

logrus.Debugf("[%s] handle connection (%s) to target (%s)", rule.Name, conn.RemoteAddr(), target.RemoteAddr())
//匹配到了,去除掉刚才设定的超时
logrus.Debugf("[%s] handle connection %s to target %s", rule.Name, conn.RemoteAddr(), target.RemoteAddr())
conn.SetReadDeadline(time.Time{})
//把第一个数据包发送给目标
io.Copy(target, bytes.NewReader(firstPacket))

defer target.Close()

//io桥
go io.Copy(conn, target)
io.Copy(target, conn)
var wg sync.WaitGroup
wg.Add(2)

var traffic1, traffic2 int64
ip, _, _ := net.SplitHostPort(conn.RemoteAddr().String())
go func() {
defer wg.Done()
traffic1 = copyWithTrafficCount(conn, target)
trafficMutex.Lock()
trafficMap[ip] += traffic1
trafficMutex.Unlock()
}()
go func() {
defer wg.Done()
traffic2 = copyWithTrafficCount(target, conn)
trafficMutex.Lock()
trafficMap[ip] += traffic2
trafficMutex.Unlock()
}()

wg.Wait()

trafficMutex.Lock()
logrus.Infof("[%s] %s to target %s: This connection traffic: %.2f MB, Total traffic: %.2f MB", rule.Name, conn.RemoteAddr().String(), targetAddress, float64(traffic1 + traffic2) / (1024 * 1024), float64(trafficMap[ip]) / (1024 * 1024))
trafficMutex.Unlock()
}

func waitFirstPacket(conn net.Conn) ([]byte, error) {
Expand All @@ -129,4 +178,29 @@ func waitFirstPacket(conn net.Conn) ([]byte, error) {
return nil, err
}
return buf[:n], nil
}
}

func copyWithTrafficCount(dst io.Writer, src io.Reader) int64 {
buf := make([]byte, 32*1024)
var traffic int64 = 0
for {
nr, er := src.Read(buf)
if nr > 0 {
nw, ew := dst.Write(buf[0:nr])
if nw > 0 {
traffic += int64(nw)
}
if ew != nil {
break
}
if nr != nw {
logrus.Errorf("partial write")
break
}
}
if er != nil {
break
}
}
return traffic
}
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
)

const (
VERSION = "2.0"
VERSION = "2.1"
)

func main() {
Expand Down