diff --git a/README.md b/README.md index 53bbfaf..dbc136b 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ "1.2.3.4":true, "114.114.114.114":true }, + "blacklist_file": "/yourblacklistfliepath", "targets": [ 目标配置 ] @@ -46,7 +47,8 @@ 3. enable_regexp为是否开启正则表达式模式,后面有解释 4. first_packet_timeout为等待客户端第一个数据包的超时时间(**毫秒**),仅开启正则表达式模式后有效,后面有解释 5. blacklist为黑名单IP,在黑名单里面的IP且为true的时候则直接断开链接。如不需要使用黑名单可留null -5. targets为目标配置数组,看下面 +6. blacklist_file文件为单行单IP黑名单,多行多IP同时黑名单 +7. targets为目标配置数组,看下面 #### 目标配置 目标配置有两种模式:**普通模式**和**正则模式**。 @@ -74,6 +76,7 @@ "1.2.3.4":true, "114.114.114.114":true }, + "blacklist_file": "/yourblacklistfliepath", "targets": [ { "address": "127.0.0.1:80" @@ -89,6 +92,7 @@ "1.2.3.4":true, "114.114.114.114":true }, + "blacklist_file": "/yourblacklistfliepath", "targets": [ { "regexp": "^(GET|POST|HEAD|DELETE|PUT|CONNECT|OPTIONS|TRACE)", diff --git a/config.go b/config.go index c2028ce..11a8bcc 100644 --- a/config.go +++ b/config.go @@ -1,12 +1,17 @@ package main import ( + "bufio" "encoding/json" "flag" "fmt" "github.com/sirupsen/logrus" "io/ioutil" + "os" "regexp" + "strings" + "sync" + "time" ) type configStructure struct { @@ -24,10 +29,12 @@ type ruleStructure struct { Address string `json:"address"` } `json:"targets"` FirstPacketTimeout uint64 `json:"first_packet_timeout"` - Blacklist map[string]bool `json:"blacklist"` + BlacklistFile string `json:"blacklist_file"` + blacklistMap map[string]bool `json:"-"` } var config *configStructure +var blacklistMutex = &sync.Mutex{} func init() { cfgPath := flag.String("config", "config.json", "config.json file path") @@ -85,5 +92,68 @@ func (c *ruleStructure) verify() error { v.regexp = r } } + if c.BlacklistFile != "" { + err := loadBlacklist(c.BlacklistFile, &c.blacklistMap) + if err != nil { + return fmt.Errorf("failed to load blacklist: %s", err.Error()) + } + go watchBlacklist(c.BlacklistFile, &c.blacklistMap) + } return nil } + +func loadBlacklist(path string, blacklist *map[string]bool) error { + file, err := os.Open(path) + if err != nil { + return err + } + defer file.Close() + + newBlacklist := make(map[string]bool) + scanner := bufio.NewScanner(file) + for scanner.Scan() { + ip := strings.TrimSpace(scanner.Text()) + newBlacklist[ip] = true + } + + if err := scanner.Err(); err != nil { + return err + } + + blacklistMutex.Lock() + oldBlacklist := *blacklist + *blacklist = newBlacklist + blacklistMutex.Unlock() + + // 打印出被移除的IP地址 + for ip := range oldBlacklist { + if !newBlacklist[ip] { + logrus.Infof("At %s, IP %s move out the Blacklist", time.Now().Format(time.RFC3339), ip) + } + } + + // 打印出新添加的IP地址 + for ip := range newBlacklist { + if !oldBlacklist[ip] { + logrus.Infof("At %s, IP %s add to the Blacklist", time.Now().Format(time.RFC3339), ip) + } + } + + return nil +} + + +func watchBlacklist(path string, blacklist *map[string]bool) { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + err := loadBlacklist(path, blacklist) + if err != nil { + logrus.Errorf("failed to reload blacklist: %s", err.Error()) + } + } + } +} diff --git a/config.json b/config.json index d72ad7f..0aea63d 100644 --- a/config.json +++ b/config.json @@ -7,6 +7,7 @@ "enable_regexp": false, "first_packet_timeout": 5000, "blacklist": null, + "blacklist_file": "/yourblacklistfliepath", "targets": [ { "regexp": "^(GET|POST|HEAD|DELETE|PUT|CONNECT|OPTIONS|TRACE)", @@ -19,4 +20,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/core.go b/core.go index 5d42769..149f5b4 100644 --- a/core.go +++ b/core.go @@ -10,6 +10,10 @@ import ( "time" ) +// 定义一个全局的流量统计map +var trafficMap = make(map[string]int64) +var trafficMutex = &sync.Mutex{} + func listen(rule *ruleStructure, wg *sync.WaitGroup) { defer wg.Done() //监听 @@ -28,10 +32,13 @@ func listen(rule *ruleStructure, wg *sync.WaitGroup) { continue } //判断黑名单 - if len(rule.Blacklist) != 0 { + blacklistMutex.Lock() + blacklist := rule.blacklistMap + blacklistMutex.Unlock() + if len(blacklist) != 0 { clientIP := conn.RemoteAddr().String() clientIP = clientIP[0:strings.LastIndex(clientIP, ":")] - if rule.Blacklist[clientIP] { + if blacklist[clientIP] { logrus.Infof("[%s] disconnected ip in blacklist: %s", rule.Name, clientIP) conn.Close() continue @@ -50,76 +57,118 @@ func handleNormal(conn net.Conn, rule *ruleStructure) { defer conn.Close() var target net.Conn - //正常模式下挨个连接直到成功连接 + var targetAddress string for _, v := range rule.Targets { c, err := net.Dial("tcp", v.Address) if err != nil { - logrus.Errorf("[%s] try to handle connection (%s) failed because target (%s) connected failed, try next target.", + logrus.Errorf("[%s] try to handle connection %s failed because target %s connected failed, try next target.", rule.Name, conn.RemoteAddr(), v.Address) continue } target = c + targetAddress = v.Address break } if target == nil { - logrus.Errorf("[%s] unable to handle connection (%s) because all targets connected failed", + logrus.Errorf("[%s] unable to handle connection %s because all targets connected failed", rule.Name, conn.RemoteAddr()) return } - logrus.Debugf("[%s] handle connection (%s) to target (%s)", rule.Name, conn.RemoteAddr(), target.RemoteAddr()) + logrus.Debugf("[%s] handle connection %s to target %s", rule.Name, conn.RemoteAddr(), target.RemoteAddr()) defer target.Close() - //io桥 - go io.Copy(conn, target) - io.Copy(target, conn) + var wg sync.WaitGroup + wg.Add(2) + + var traffic1, traffic2 int64 + ip, _, _ := net.SplitHostPort(conn.RemoteAddr().String()) + go func() { + defer wg.Done() + traffic1 = copyWithTrafficCount(conn, target) + trafficMutex.Lock() + trafficMap[ip] += traffic1 + trafficMutex.Unlock() + }() + go func() { + defer wg.Done() + traffic2 = copyWithTrafficCount(target, conn) + trafficMutex.Lock() + trafficMap[ip] += traffic2 + trafficMutex.Unlock() + }() + + wg.Wait() + + trafficMutex.Lock() + logrus.Infof("[%s] %s to target %s: This connection traffic: %.2f MB, Total traffic: %.2f MB", rule.Name, conn.RemoteAddr().String(), targetAddress, float64(traffic1 + traffic2) / (1024 * 1024), float64(trafficMap[ip]) / (1024 * 1024)) + trafficMutex.Unlock() } func handleRegexp(conn net.Conn, rule *ruleStructure) { defer conn.Close() - //正则模式下需要客户端的第一个数据包判断特征,所以需要设置一个超时 conn.SetReadDeadline(time.Now().Add(time.Millisecond * time.Duration(rule.FirstPacketTimeout))) - //获取第一个数据包 firstPacket, err := waitFirstPacket(conn) if err != nil { - logrus.Errorf("[%s] unable to handle connection (%s) because failed to get first packet : %s", + logrus.Errorf("[%s] unable to handle connection %s because failed to get first packet : %s", rule.Name, conn.RemoteAddr(), err.Error()) return } var target net.Conn - //挨个匹配正则 + var targetAddress string for _, v := range rule.Targets { if !v.regexp.Match(firstPacket) { continue } c, err := net.Dial("tcp", v.Address) if err != nil { - logrus.Errorf("[%s] try to handle connection (%s) failed because target (%s) connected failed, try next match target.", + logrus.Errorf("[%s] try to handle connection %s failed because target %s connected failed, try next match target.", rule.Name, conn.RemoteAddr(), v.Address) continue } target = c + targetAddress = v.Address break } if target == nil { - logrus.Errorf("[%s] unable to handle connection (%s) because no match target", + logrus.Errorf("[%s] unable to handle connection %s because no match target", rule.Name, conn.RemoteAddr()) return } - logrus.Debugf("[%s] handle connection (%s) to target (%s)", rule.Name, conn.RemoteAddr(), target.RemoteAddr()) - //匹配到了,去除掉刚才设定的超时 + logrus.Debugf("[%s] handle connection %s to target %s", rule.Name, conn.RemoteAddr(), target.RemoteAddr()) conn.SetReadDeadline(time.Time{}) - //把第一个数据包发送给目标 io.Copy(target, bytes.NewReader(firstPacket)) defer target.Close() - //io桥 - go io.Copy(conn, target) - io.Copy(target, conn) + var wg sync.WaitGroup + wg.Add(2) + + var traffic1, traffic2 int64 + ip, _, _ := net.SplitHostPort(conn.RemoteAddr().String()) + go func() { + defer wg.Done() + traffic1 = copyWithTrafficCount(conn, target) + trafficMutex.Lock() + trafficMap[ip] += traffic1 + trafficMutex.Unlock() + }() + go func() { + defer wg.Done() + traffic2 = copyWithTrafficCount(target, conn) + trafficMutex.Lock() + trafficMap[ip] += traffic2 + trafficMutex.Unlock() + }() + + wg.Wait() + + trafficMutex.Lock() + logrus.Infof("[%s] %s to target %s: This connection traffic: %.2f MB, Total traffic: %.2f MB", rule.Name, conn.RemoteAddr().String(), targetAddress, float64(traffic1 + traffic2) / (1024 * 1024), float64(trafficMap[ip]) / (1024 * 1024)) + trafficMutex.Unlock() } func waitFirstPacket(conn net.Conn) ([]byte, error) { @@ -129,4 +178,29 @@ func waitFirstPacket(conn net.Conn) ([]byte, error) { return nil, err } return buf[:n], nil -} \ No newline at end of file +} + +func copyWithTrafficCount(dst io.Writer, src io.Reader) int64 { + buf := make([]byte, 32*1024) + var traffic int64 = 0 + for { + nr, er := src.Read(buf) + if nr > 0 { + nw, ew := dst.Write(buf[0:nr]) + if nw > 0 { + traffic += int64(nw) + } + if ew != nil { + break + } + if nr != nw { + logrus.Errorf("partial write") + break + } + } + if er != nil { + break + } + } + return traffic +} diff --git a/main.go b/main.go index dd52ffe..9e3cac4 100644 --- a/main.go +++ b/main.go @@ -6,7 +6,7 @@ import ( ) const ( - VERSION = "2.0" + VERSION = "2.1" ) func main() {