// GadElKareem
//
// Golang IP rate limiter with automatically fetched trusted-proxy ranges and
// reverse-DNS verification of search-engine bots.

package main

package main

import (
	"bufio"
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/astaxie/beego"
	"github.com/astaxie/beego/context"
	"github.com/astaxie/beego/logs"
	"github.com/ulule/limiter"
	"github.com/ulule/limiter/drivers/store/memory"
)

// rateLimiter bundles the state consulted by the rateLimit filter.
type rateLimiter struct {
	// generalLimiter is charged for every request whose URL does not start
	// with "/login".
	generalLimiter *limiter.Limiter
	// loginLimiter is charged for requests whose URL starts with "/login".
	loginLimiter *limiter.Limiter

	// proxyBlocks holds the networks rateLimit treats as trusted proxies;
	// it is filled by initProxyBlocks.
	proxyBlocks []*net.IPNet

	// botIpCacheMu guards botIpCache.
	botIpCacheMu sync.RWMutex
	// botIpCache remembers IPs already accepted as trusted bots.
	botIpCache map[string]bool
}

// proxyUrls are downloaded by initProxyBlocks at start-up; each one must serve
// a newline-separated list of CIDR ranges (see fetchProxyBlocks).
var proxyUrls = []string{
	"https://www.cloudflare.com/ips-v4",
	"https://www.cloudflare.com/ips-v6",
}

// proxyCidrs are trusted-proxy ranges that are always added after the
// downloaded ones: the IPv4 and IPv6 loopback hosts.
var proxyCidrs = []string{
	"127.0.0.1/32",
	"::1/128",
}

// main wires the rate limiter into Beego. It installs a filter that runs
// rateLimit before routing on every path, registers the body served for 429
// aborts, and starts the server.
func main() {
	limiters := &rateLimiter{
		// Loopback addresses are pre-seeded as trusted bots, so rateLimit
		// lets requests keyed to them through even when over the limit.
		botIpCache: map[string]bool{
			"127.0.0.1": true,
			"::1":       true,
		},
		generalLimiter: newLimiter("2-S"),
		loginLimiter:   newLimiter("2-M"),
	}

	limiters.initProxyBlocks()

	// Filter docs: https://beego.me/docs/mvc/controller/filter.md
	beego.InsertFilter("/*", beego.BeforeRouter, func(ctx *context.Context) {
		rateLimit(limiters, ctx)
	}, true)

	// Error handler docs: https://beego.me/docs/mvc/controller/errors.md
	beego.ErrorHandler("429", func(w http.ResponseWriter, req *http.Request) {
		w.WriteHeader(http.StatusTooManyRequests)
		w.Write([]byte("Too Many Requests"))
	})

	beego.Run()
}

// rateLimit is the body of the BeforeRouter filter installed in main.
//
// It does the following:
//   - resolves the client IP with clientIP;
//   - charges one hit for that IP against r.loginLimiter (URLs starting with
//     "/login") or r.generalLimiter (everything else);
//   - publishes the X-RateLimit-Limit/-Remaining/-Reset headers;
//   - once the limit is reached, aborts with 429 unless isTrustedBot accepts
//     the IP.
//
// A limiter store error aborts with 500.
func rateLimit(r *rateLimiter, ctx *context.Context) {
	req := ctx.Request
	ip := r.clientIP(req)
	key := ip.String()

	var (
		limiterCtx limiter.Context
		err        error
	)
	if strings.HasPrefix(ctx.Input.URL(), "/login") {
		limiterCtx, err = r.loginLimiter.Get(req.Context(), key)
	} else {
		limiterCtx, err = r.generalLimiter.Get(req.Context(), key)
	}
	if err != nil {
		ctx.Abort(http.StatusInternalServerError, err.Error())
		return
	}

	// Set (not Add) so a header never carries more than one value.
	h := ctx.ResponseWriter.Header()
	h.Set("X-RateLimit-Limit", strconv.FormatInt(limiterCtx.Limit, 10))
	h.Set("X-RateLimit-Remaining", strconv.FormatInt(limiterCtx.Remaining, 10))
	h.Set("X-RateLimit-Reset", strconv.FormatInt(limiterCtx.Reset, 10))

	if !limiterCtx.Reached {
		return
	}
	if r.isTrustedBot(key) {
		// Verified crawlers are let through even when over the limit.
		return
	}
	logs.Debug("Too Many Requests from %s on %s", ip, ctx.Input.URL())
	// Beego presumably dispatches an Abort whose body is "429" to the handler
	// registered with beego.ErrorHandler("429", ...) in main; see
	// https://beego.me/docs/mvc/controller/errors.md
	ctx.Abort(http.StatusTooManyRequests, "429")
}

// clientIP returns the address req should be rate-limited by.
//
// The direct TCP peer (req.RemoteAddr) is checked first. Forwarding headers
// are honoured only when that peer lies inside r.proxyBlocks. Otherwise
// X-Forwarded-For and X-Real-IP are removed, so that neither this filter nor
// downstream handlers act on client-supplied values, and the peer address is
// used. Checking the peer rather than the forwarded value is what stops a
// client from choosing its own rate-limit key.
//
// NOTE(review): the limiter's GetIP appears to take the left-most
// X-Forwarded-For entry, which a client can still forge through a proxy that
// appends to the header. Cloudflare's CF-Connecting-IP may be a safer source
// for Cloudflare traffic; confirm before relying on either.
func (r *rateLimiter) clientIP(req *http.Request) net.IP {
	peer := remoteIP(req)
	if peer != nil && r.isTrustedProxy(peer) {
		// generalLimiter is configured with WithTrustForwardHeader, so
		// GetIP reads the forwarding headers here.
		if fwd := r.generalLimiter.GetIP(req); fwd != nil {
			return fwd
		}
		// Malformed forwarding header: fall back to the proxy's address.
		return peer
	}
	req.Header.Del("X-Forwarded-For")
	req.Header.Del("X-Real-IP")
	return peer
}

// isTrustedProxy reports whether ip lies inside any of r.proxyBlocks.
func (r *rateLimiter) isTrustedProxy(ip net.IP) bool {
	for _, block := range r.proxyBlocks {
		if block.Contains(ip) {
			return true
		}
	}
	return false
}

// remoteIP parses the host part of req.RemoteAddr. It accepts a bare address
// without a port and returns nil when the value is not an IP.
func remoteIP(req *http.Request) net.IP {
	addr := strings.TrimSpace(req.RemoteAddr)
	host, _, err := net.SplitHostPort(addr)
	if err != nil {
		host = addr
	}
	return net.ParseIP(host)
}

// initProxyBlocks fills r.proxyBlocks with the trusted proxy networks. It adds
// every range downloaded from proxyUrls first, then the static proxyCidrs.
// Sources that cannot be fetched and entries that do not parse are skipped;
// the helpers log those failures.
func (r *rateLimiter) initProxyBlocks() {
	for _, src := range proxyUrls {
		// Appending a nil slice is a no-op, so failed downloads add nothing.
		r.proxyBlocks = append(r.proxyBlocks, r.fetchProxyBlocks(src)...)
	}
	for _, cidr := range proxyCidrs {
		if network := parseCidr(cidr); network != nil {
			r.proxyBlocks = append(r.proxyBlocks, network)
		}
	}
}

// proxyListClient downloads the trusted-proxy CIDR lists. http.Get uses
// http.DefaultClient, which has no timeout, so an unresponsive list server
// could otherwise block start-up forever.
var proxyListClient = &http.Client{Timeout: 15 * time.Second}

// fetchProxyBlocks downloads urlString and parses the body as a list of CIDR
// ranges, one per line. Blank lines are ignored; lines that fail to parse are
// logged by parseCidr and skipped.
//
// It returns nil when the request fails, when the server answers with a
// non-200 status, or when reading the body fails. This way an error page or a
// truncated download is never taken for a proxy list.
func (r *rateLimiter) fetchProxyBlocks(urlString string) []*net.IPNet {
	response, err := proxyListClient.Get(urlString)
	if err != nil {
		logs.Error("http.Get => %v", err)
		return nil
	}
	defer response.Body.Close()

	if response.StatusCode != http.StatusOK {
		logs.Error("Unexpected status %d fetching proxy list %s", response.StatusCode, urlString)
		return nil
	}

	var blocks []*net.IPNet
	scanner := bufio.NewScanner(response.Body)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}
		if block := parseCidr(line); block != nil {
			blocks = append(blocks, block)
		}
	}
	if err := scanner.Err(); err != nil {
		logs.Error("Error scanning proxy list: %s", err)
		return nil
	}
	return blocks
}

// trustedBotDomains lists the DNS zones whose hosts are accepted as
// search-engine crawlers. Entries are fully qualified, with a trailing dot.
var trustedBotDomains = []string{
	"google.com.",
	"googlebot.com.",
	"msn.com.",
	"yandex.com.",
	"yandex.net.",
	"baidu.com.",
	"yahoo.com.",
}

// isTrustedBotHost reports whether host is one of trustedBotDomains or a
// subdomain of one. It matches case-insensitively and tolerates a missing
// trailing dot. The match is anchored on a label boundary, so look-alike
// zones such as "evilgoogle.com." do not pass.
func isTrustedBotHost(host string) bool {
	host = strings.ToLower(host)
	if !strings.HasSuffix(host, ".") {
		host += "."
	}
	for _, domain := range trustedBotDomains {
		if host == domain || strings.HasSuffix(host, "."+domain) {
			return true
		}
	}
	return false
}

// isTrustedBot reports whether ip belongs to a verified search-engine crawler.
// It uses forward-confirmed reverse DNS: a PTR name for ip must fall inside
// trustedBotDomains AND resolve back to ip.
//
// Positive results are cached in r.botIpCache. Negative results are not
// cached, so each call for an unverified IP performs blocking DNS lookups.
func (r *rateLimiter) isTrustedBot(ip string) bool {
	r.botIpCacheMu.RLock()
	cached := r.botIpCache[ip]
	r.botIpCacheMu.RUnlock()
	if cached {
		return true
	}

	names, err := net.LookupAddr(ip)
	if err != nil {
		logs.Error("Error getting host for ip %s Error: %v", ip, err)
		return false
	}
	want := net.ParseIP(ip)
	for _, name := range names {
		if !isTrustedBotHost(name) {
			continue
		}
		addrs, err := net.LookupHost(name)
		if err != nil {
			// Keep checking the remaining PTR names instead of giving up.
			logs.Error("Error resolving host %s for ip %s Error: %v", name, ip, err)
			continue
		}
		for _, addr := range addrs {
			// Compare parsed IPs so textual variants of one address match.
			if net.ParseIP(addr).Equal(want) {
				r.botIpCacheMu.Lock()
				r.botIpCache[ip] = true
				r.botIpCacheMu.Unlock()
				return true
			}
		}
	}
	return false
}

// parseCidr parses s in CIDR notation and returns its network. When s is not
// valid CIDR it logs the parse error and returns nil.
func parseCidr(s string) *net.IPNet {
	_, network, err := net.ParseCIDR(s)
	if err == nil {
		return network
	}
	logs.Error("Error parsing proxy block: %v", err)
	return nil
}

// PanicOnError panics with e when e is non-nil and does nothing otherwise.
// In this file it is used to abort start-up on an invalid limiter rate.
func PanicOnError(e error) {
	if e == nil {
		return
	}
	panic(e)
}

// newLimiter builds an in-memory limiter for rate, given in limiter's
// formatted notation such as "2-S". It panics if rate cannot be parsed.
// Forwarding headers are trusted by the limiter's GetIP; rateLimit decides
// when that lookup is used.
func newLimiter(rate string) *limiter.Limiter {
	parsed, err := limiter.NewRateFromFormatted(rate)
	PanicOnError(err)
	store := memory.NewStore()
	return limiter.New(store, parsed, limiter.WithTrustForwardHeader(true))
}

Fork it