package main
package main
import (
	"bufio"
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/astaxie/beego"
	"github.com/astaxie/beego/context"
	"github.com/astaxie/beego/logs"
	"github.com/ulule/limiter"
	"github.com/ulule/limiter/drivers/store/memory"
)
// rateLimiter holds the limiters applied by rateLimit together with the
// state used to decide which clients sit behind a trusted proxy and which
// clients are exempt from the limit.
type rateLimiter struct {
	// proxyBlocks are the networks treated as trusted proxies.
	proxyBlocks []*net.IPNet

	// generalLimiter throttles every path that does not start with /login.
	generalLimiter *limiter.Limiter
	// loginLimiter throttles paths starting with /login.
	loginLimiter *limiter.Limiter

	// botIpCacheMu guards botIpCache.
	botIpCacheMu sync.RWMutex
	// botIpCache records addresses already accepted by isTrustedBot
	// (loopback addresses are pre-seeded by main).
	botIpCache map[string]bool
}
// proxyUrls are plain-text lists of Cloudflare's IPv4 and IPv6 ranges, one
// CIDR per line; initProxyBlocks downloads them at start-up.
var proxyUrls = []string{
	"https://www.cloudflare.com/ips-v4",
	"https://www.cloudflare.com/ips-v6",
}

// proxyCidrs are statically trusted proxy networks: the IPv4 and IPv6
// loopback hosts (presumably for a reverse proxy on the same machine).
var proxyCidrs = []string{
	"127.0.0.1/32",
	"::1/128",
}
// main builds the limiters (2 requests per "S" period in general, 2 per "M"
// period for /login), loads trusted proxy ranges, registers the rate-limit
// filter and a 429 error page with beego, and then starts the server.
func main() {
	limits := &rateLimiter{
		// Loopback clients are treated as trusted bots and never throttled.
		botIpCache: map[string]bool{
			"127.0.0.1": true,
			"::1":       true,
		},
	}
	limits.generalLimiter = newLimiter("2-S")
	limits.loginLimiter = newLimiter("2-M")
	limits.initProxyBlocks()

	// Beego filters: https://beego.me/docs/mvc/controller/filter.md
	beego.InsertFilter("/*", beego.BeforeRouter, func(ctx *context.Context) {
		rateLimit(limits, ctx)
	}, true)

	// Custom error handlers: https://beego.me/docs/mvc/controller/errors.md
	beego.ErrorHandler("429", func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusTooManyRequests)
		w.Write([]byte("Too Many Requests"))
	})

	beego.Run()
}
func rateLimit(r *rateLimiter, ctx *context.Context) {
var (
limiterCtx limiter.Context
err error
req = ctx.Request
trustedProxy = false
)
ip := r.generalLimiter.GetIP(req)
for _, block := range r.proxyBlocks {
if block.Contains(ip) {
trustedProxy = true
ip = r.generalLimiter.GetIP(req)
break
}
}
//remove Forwarded for header if unknown proxy
if !trustedProxy {
req.Header.Del("X-Forwarded-For")
}
if strings.HasPrefix(ctx.Input.URL(), "/login") {
limiterCtx, err = r.loginLimiter.Get(req.Context(), ip.String())
} else {
limiterCtx, err = r.generalLimiter.Get(req.Context(), ip.String())
}
if err != nil {
ctx.Abort(http.StatusInternalServerError, err.Error())
return
}
h := ctx.ResponseWriter.Header()
h.Add("X-RateLimit-Limit", strconv.FormatInt(limiterCtx.Limit, 10))
h.Add("X-RateLimit-Remaining", strconv.FormatInt(limiterCtx.Remaining, 10))
h.Add("X-RateLimit-Reset", strconv.FormatInt(limiterCtx.Reset, 10))
if limiterCtx.Reached {
if r.isTrustedBot(ip.String()) {
//beetak abasha...
return
}
logs.Debug("Too Many Requests from %s on %s", ip, ctx.Input.URL())
//refer to https://beego.me/docs/mvc/controller/errors.md for error handling
ctx.Abort(http.StatusTooManyRequests, "429")
return
}
}
// initProxyBlocks fills r.proxyBlocks with the CIDRs downloaded from each of
// proxyUrls, followed by the static proxyCidrs entries.
//
// A list for which fetchProxyBlocks returns nil contributes nothing.
// Malformed static entries are dropped (parseCidr logs and returns nil).
func (r *rateLimiter) initProxyBlocks() {
	for _, listURL := range proxyUrls {
		if fetched := r.fetchProxyBlocks(listURL); fetched != nil {
			r.proxyBlocks = append(r.proxyBlocks, fetched...)
		}
	}
	for _, cidr := range proxyCidrs {
		parsed := parseCidr(cidr)
		if parsed == nil {
			continue
		}
		r.proxyBlocks = append(r.proxyBlocks, parsed)
	}
}
func (r *rateLimiter) fetchProxyBlocks(urlString string) []*net.IPNet {
response, err := http.Get(urlString)
if err != nil {
logs.Error("http.Get => %v", err)
return nil
}
defer response.Body.Close()
var (
line string
block *net.IPNet
blocks []*net.IPNet
)
scanner := bufio.NewScanner(response.Body)
for scanner.Scan() {
line = strings.TrimSpace(scanner.Text())
if line != "" {
if block = parseCidr(line); block != nil {
blocks = append(blocks, block)
}
}
}
if scanner.Err() != nil {
logs.Error("Error scanning proxy list: %s", scanner.Err())
return nil
}
return blocks
}
// trustedBotDomains are the DNS zones whose crawlers may exceed the rate limit
// once their address passes the forward-confirmed reverse DNS check in
// isTrustedBot. Entries are lower-case and have no trailing root dot.
var trustedBotDomains = []string{
	"google.com",
	"googlebot.com",
	"msn.com",
	"yandex.com",
	"yandex.net",
	"baidu.com",
	"yahoo.com",
}

// isTrustedBotHost reports whether host (a name returned by net.LookupAddr,
// with or without a trailing dot, in any letter case) equals one of
// trustedBotDomains or is a subdomain of one.
//
// The match must fall on a label boundary. A plain suffix test accepts
// "evilgoogle.com.", whose owner controls both its PTR and forward records
// and could therefore pass the forward confirmation.
func isTrustedBotHost(host string) bool {
	host = strings.TrimSuffix(strings.ToLower(host), ".")
	for _, domain := range trustedBotDomains {
		if host == domain || strings.HasSuffix(host, "."+domain) {
			return true
		}
	}
	return false
}

// isTrustedBot reports whether ip (in textual form) belongs to a known search
// engine crawler, using forward-confirmed reverse DNS. A PTR name of ip must
// fall under trustedBotDomains, and that name must resolve back to ip.
//
// Addresses already in r.botIpCache are trusted without any lookup, and
// successful verifications are added to it. Failures are not cached, so each
// call for an unverified address performs DNS lookups. Lookup errors are
// logged and treated as "not trusted".
func (r *rateLimiter) isTrustedBot(ip string) bool {
	r.botIpCacheMu.RLock()
	_, cached := r.botIpCache[ip]
	r.botIpCacheMu.RUnlock()
	if cached {
		return true
	}

	names, err := net.LookupAddr(ip)
	if err != nil {
		logs.Error("Error getting host for ip %s Error: %v", ip, err)
		return false
	}

	target := net.ParseIP(ip)
	for _, name := range names {
		if !isTrustedBotHost(name) {
			continue
		}
		addrs, err := net.LookupHost(name)
		if err != nil {
			// Try the remaining PTR names rather than giving up.
			logs.Error("Error resolving host %s for ip %s Error: %v", name, ip, err)
			continue
		}
		for _, addr := range addrs {
			// Compare parsed addresses so equivalent spellings (e.g. IPv6
			// zero compression, IPv4-mapped forms) still match.
			if target != nil && target.Equal(net.ParseIP(addr)) {
				r.botIpCacheMu.Lock()
				r.botIpCache[ip] = true
				r.botIpCacheMu.Unlock()
				return true
			}
		}
	}
	return false
}
// parseCidr parses s in CIDR notation (for example "127.0.0.1/32") and returns
// the network it denotes. If s is malformed, it logs the error and returns nil.
func parseCidr(s string) *net.IPNet {
	_, network, err := net.ParseCIDR(s)
	if err == nil {
		return network
	}
	logs.Error("Error parsing proxy block: %v", err)
	return nil
}
// PanicOnError panics with e when e is non-nil and does nothing for a nil
// error. It is used for failures that should abort start-up.
func PanicOnError(e error) {
	if e == nil {
		return
	}
	panic(e)
}
// newLimiter returns a limiter backed by memory.NewStore for the given
// formatted rate (ulule/limiter "<limit>-<period>" syntax, e.g. "2-S" as used
// in main). The limiter is built with WithTrustForwardHeader(true). It panics
// through PanicOnError if rate cannot be parsed.
func newLimiter(rate string) *limiter.Limiter {
	parsedRate, err := limiter.NewRateFromFormatted(rate)
	PanicOnError(err)
	store := memory.NewStore()
	return limiter.New(store, parsedRate, limiter.WithTrustForwardHeader(true))
}
Fork it