API Rate Limiting in Go with Redis

Posted on Nov 20, 2023

Last week, I was looking through the access logs of a Go application I built a while back. It has a couple of APIs designed to fetch DNS information from various geographical regions. This app is quite simple in what it does. At the time of development, I didn’t implement any rate limiting for these APIs. While I would like to say it’s because I believe in a trust-based society, the truth is I was just a bit lazy.

The app doesn’t have many users. Initially, peak traffic averaged around 4,000 requests per minute. In recent weeks, however, I noticed sudden bursts of traffic, sometimes reaching around 20,000 requests per minute. As I sifted through the logs, I saw that most of these bursts came from a single IP address. It looked like a system out there was compiling a list of domains and then hitting our APIs once for each one. I realize I’m partly to blame for this: my laziness led me to skip implementing a batch API, which could have mitigated the problem in the first place.

So, in the spirit of being a resilient system designer (and not just because it’s long overdue), let’s finally implement a rate limiting mechanism!

Architecture

Let’s first take a look at the system architecture.

There are two stateless instances of the application running on two VMs behind a load balancer. Both instances are connected to a Redis instance (in reality, it’s KeyDB, a Redis-compatible fork).

The load balancer is simply Nginx, which makes IP-based rate limiting relatively straightforward. Nginx supports it out of the box (via its limit_req module), and it works great.

While this approach would solve the immediate issue of thousands of requests from a single IP, I want to take it a step further. My goal is to apply rate limits based on the user’s subscription plan. For instance,

  • Free plan users get 100 requests/minute.
  • Starter plan users get 3000 requests/minute.
  • and so on.

This means our rate limiting mechanism needs to be aware of the user’s current subscription plan, which can only be determined once the request reaches one of the application instances.
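To make the plan-based idea concrete, here is a minimal sketch of a plan-to-quota table using the two example plans above. The plan identifiers, the planLimits map, and the fall-back-to-free rule are hypothetical; in the real app the quota arrives on the user’s Subscription struct, as shown later.

```go
package main

import "fmt"

// planLimits is a hypothetical plan-to-quota table (requests per minute).
// In the real app the quota comes from the Subscription struct instead.
var planLimits = map[string]int64{
	"free":    100,
	"starter": 3000,
}

// limitForPlan returns the quota for a plan. Falling back to the free quota
// for unknown plans is an illustrative choice, not the article's.
func limitForPlan(plan string) int64 {
	if limit, ok := planLimits[plan]; ok {
		return limit
	}
	return planLimits["free"]
}

func main() {
	fmt.Println(limitForPlan("starter")) // 3000
	fmt.Println(limitForPlan("unknown")) // 100
}
```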

Since the app instances are stateless, they use Redis to keep track of the request count per user.

Algorithm Overview

There are various rate limiting algorithms out there, each with its own strengths and suited for specific problems. I opted for the Sliding Window Log algorithm. While there are plenty of excellent write-ups available online, I’d like to explain it in my own way.

Basically, the algorithm works as follows:

  • We start by defining a window of interest for tracking requests. This window is a specific time frame; in our case, it’s 1 minute.
  • For each user, a log of request timestamps is maintained.
  • When a new request arrives, any timestamps in the log that fall outside the current window of interest are removed.
  • After clearing outdated timestamps, the new request’s timestamp is appended to the log.
  • We then check the size of the log. If the number of timestamps in the log exceeds the predefined limit (i.e., the maximum number of requests allowed in the time frame), the request is denied.

Algorithm Implementation

Now, to make sure our rate limiting works correctly across all app instances, the whole check-and-record step has to happen atomically. Our app instances are stateless, so if two requests from the same user hit different instances at the same time, both requests must count towards the rate limit. If each instance read the count and wrote its timestamp as separate steps, both could see the same count and both could be admitted, letting the user exceed their quota.

This is where I find Lua scripts in Redis very handy: Redis runs a script atomically, so no other command can interleave with it. Sure, Redis transactions could do the job too, but in my experience they get complicated fast. Lua scripts keep things neat and tidy, with the entire algorithm logic in one place.
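To see why the check and the insert must happen as one step, here is a deterministic replay of the race described above. The sharedStore type is a hypothetical stand-in for the per-user sorted set that only tracks a count; the "atomic" variant simply lets each check-and-add finish before the next starts, which is the guarantee a Lua script gives us inside Redis.

```go
package main

import "fmt"

// sharedStore stands in for one user's request log in Redis; it only
// tracks how many requests have been recorded. Purely illustrative.
type sharedStore struct{ count int }

// nonAtomicInterleaving replays two instances doing check-then-add as
// separate steps, with both reads happening before either write.
func nonAtomicInterleaving(store *sharedStore, limit int) (admitted int) {
	seenByA := store.count
	seenByB := store.count

	if seenByA < limit {
		store.count++
		admitted++
	}
	if seenByB < limit {
		store.count++
		admitted++
	}
	return admitted
}

// atomicSequence runs the same two requests, but each check-and-add
// completes before the next one starts.
func atomicSequence(store *sharedStore, limit int) (admitted int) {
	for i := 0; i < 2; i++ {
		if store.count < limit {
			store.count++
			admitted++
		}
	}
	return admitted
}

func main() {
	a := &sharedStore{count: 99}
	fmt.Println(nonAtomicInterleaving(a, 100), a.count) // 2 101
	b := &sharedStore{count: 99}
	fmt.Println(atomicSequence(b, 100), b.count) // 1 100
}
```

With one slot left, the interleaved version admits both requests and overshoots the limit, while the atomic version admits exactly one.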

Here’s the Lua script I am using for rate limiting:

local key = KEYS[1]
local limit = tonumber(ARGV[1])
local windowSize = tonumber(ARGV[2])
local currentTime = tonumber(ARGV[3])

-- Remove all timestamps outside our window of interest
redis.call('ZREMRANGEBYSCORE', key, '-inf', '(' .. currentTime - windowSize)

-- Get the current count from the sorted set
local currentCount = redis.call('ZCARD', key)

if currentCount >= limit then
    return currentCount
end

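-- Use a unique member: ZADD with an existing member only updates its score,
-- so two requests in the same millisecond would otherwise count as one
local seq = 0
local member = currentTime .. '-' .. seq
while redis.call('ZSCORE', key, member) do
    seq = seq + 1
    member = currentTime .. '-' .. seq
end
redis.call('ZADD', key, currentTime, member)

-- Let idle users' keys expire once a whole window has passed
redis.call('PEXPIRE', key, windowSize)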
return currentCount

The script is pretty self-explanatory, but let me still break down a few key points:

  • We could get the currentTime from Redis (with the TIME command), but passing it in makes writing tests easier. The trade-off is that the app VMs’ clocks need to stay in sync (e.g., via NTP), since each instance supplies its own timestamp.
  • if currentCount >= limit then - In the algorithm overview above, I described adding the request to the log first and then checking the count. In my actual implementation, I check the count before adding the request, because I want only successful requests to count towards the rate limit, not every API invocation. Users can hit the API thousands of times, but the app will only serve data for as many requests as their subscription plan allows.
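The difference between the overview and the script is easiest to see in a small in-memory mirror of the script’s logic. This is a sketch, not the production code: scriptLog and run are hypothetical names, a plain slice stands in for the sorted set, and timestamps are Unix milliseconds as passed in ARGV. Because currentTime is a parameter, the behaviour can be tested without a clock or a Redis server.

```go
package main

import "fmt"

// scriptLog is an in-memory stand-in for one user's sorted set, used to
// mirror the Lua script step by step. Timestamps are Unix milliseconds.
type scriptLog struct {
	entries []int64
}

// run mirrors the script: ZREMRANGEBYSCORE with an exclusive '(' bound keeps
// entries at exactly currentTime-windowSize, ZCARD counts what is left, and
// the request is only logged (ZADD) when the count is below the limit. Like
// the script, it returns the count from before this request.
func (s *scriptLog) run(limit, windowSize, currentTime int64) int64 {
	cutoff := currentTime - windowSize
	kept := s.entries[:0]
	for _, ts := range s.entries {
		if ts >= cutoff {
			kept = append(kept, ts)
		}
	}
	s.entries = kept

	currentCount := int64(len(s.entries))
	if currentCount >= limit {
		return currentCount // rejected requests are not logged
	}
	s.entries = append(s.entries, currentTime)
	return currentCount
}

func main() {
	var userLog scriptLog
	// A burst of 5 requests within one minute against a limit of 3.
	var counts []int64
	for ts := int64(1000); ts < 1005; ts++ {
		counts = append(counts, userLog.run(3, 60000, ts))
	}
	fmt.Println(counts, len(userLog.entries)) // [0 1 2 3 3] 3
}
```

Only the three admitted requests end up in the log, so a client hammering the API while over its limit does not push its own window further out.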

Implementation

Now, let’s dive into the application code that ties everything together.

My app is built with the Fiber framework and uses the go-redis library. However, integrating this into any Go application should be fairly straightforward, regardless of the framework you use.

First, we need to create a handle for the Lua script:

rateLimitScript := redis.NewScript(luaScript)

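This rateLimitScript is of type *redis.Script. NewScript only stores the source and its SHA1 digest; Run first tries EVALSHA and falls back to EVAL if Redis hasn’t cached the script yet. So we can create it once and reuse it throughout our app’s lifetime.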
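Redis identifies cached scripts by the lowercase hex SHA1 of their source, which is why the handle can be reused for every request. The sketch below computes that digest with only the standard library; scriptDigest is a hypothetical helper, and its result should match what rateLimitScript.Hash() reports for the same source.

```go
package main

import (
	"crypto/sha1"
	"encoding/hex"
	"fmt"
)

// scriptDigest returns the hex SHA1 of a Lua script's source, the
// identifier Redis expects in EVALSHA.
func scriptDigest(src string) string {
	sum := sha1.Sum([]byte(src))
	return hex.EncodeToString(sum[:])
}

func main() {
	fmt.Println(scriptDigest("return 1"))
}
```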

Next, we need the user information to find out how many requests they are allowed to make. In my setup, a Fiber middleware that runs right before the rate limiter extracts this information from the API token in the request into a Subscription struct and stores it in the request context (Fiber’s c.Locals). Here’s what the struct looks like:

type Subscription struct {
	UserId            string 
	RequestsPerMinute int64
	// ..other fields
}
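In Fiber v2, c.Locals(key) returns an untyped interface{}, so the rate limiter has to type-assert the value back before reading its fields. A minimal sketch of a safe accessor, assuming the middleware stores a *Subscription pointer (the article doesn’t say whether it stores a pointer or a value); subscriptionFromLocals is a hypothetical helper, and the struct is trimmed to the two fields used here.

```go
package main

import "fmt"

// Subscription is trimmed to the two fields the rate limiter uses.
type Subscription struct {
	UserId            string
	RequestsPerMinute int64
}

// subscriptionFromLocals turns the untyped value returned by c.Locals back
// into a *Subscription. It assumes the middleware stored a pointer; a
// missing, nil, or differently typed value yields ok == false, not a panic.
func subscriptionFromLocals(v interface{}) (*Subscription, bool) {
	sub, ok := v.(*Subscription)
	if !ok || sub == nil {
		return nil, false
	}
	return sub, true
}

func main() {
	stored := &Subscription{UserId: "user-42", RequestsPerMinute: 100}
	if sub, ok := subscriptionFromLocals(stored); ok {
		fmt.Println(sub.UserId, sub.RequestsPerMinute) // user-42 100
	}
	_, ok := subscriptionFromLocals(nil)
	fmt.Println(ok) // false
}
```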

We can use the subscription information to construct the inputs for our script:

userDataKey := sub.UserId + "-v1"

now := time.Now().UTC()
duration := time.Minute
expiresAt := now.Add(duration)

key := []string{userDataKey}
values := []interface{}{sub.RequestsPerMinute, duration.Milliseconds(), now.UnixMilli()}

Our window of interest is 1 minute, or 60,000 milliseconds. I use UTC out of habit, although UnixMilli() returns milliseconds since the Unix epoch regardless of time zone, so every instance computes the same window as long as their clocks agree.

Finally, execute the script:

requestsInWindow, err := rateLimitScript.Run(c.Context(), redisClient, key, values).Int64()

We now have the number of requests this user made in the current window (i.e., in the last minute), not counting the current one. We can compare it against the number of requests they are allowed to make and decide whether to allow the current request.

if requestsInWindow >= sub.RequestsPerMinute {
	return c.Status(fiber.StatusTooManyRequests).JSON(fiber.Map{
		"error_code": "rate_limit_exceeded",
	})
}

If the user has used up their quota, we respond with a 429 Too Many Requests error.

I also like to send rate limit headers to the client, as Twitch and Okta do. These are simple to implement, since we already have the necessary data.

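// The script returns the count before this request was logged, so an
// admitted request leaves one slot fewer; clamp at zero when rejected
remaining := sub.RequestsPerMinute - requestsInWindow - 1
if remaining < 0 {
	remaining = 0
}

c.Set("RateLimit-Limit", strconv.FormatInt(sub.RequestsPerMinute, 10))
c.Set("RateLimit-Remaining", strconv.FormatInt(remaining, 10))
c.Set("RateLimit-Reset", strconv.FormatInt(expiresAt.Unix(), 10))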

And that’s it! We have a fairly simple but scalable rate limiter in place.

Outro

It’s important to note that this rate limiter is not a generic solution but a feature-specific one tailored to my application’s needs. For broader protection, I have a generic IP-based rate limiter set up in Nginx, which caps requests from any single IP at 50,000 per minute. I set that cap well above the 20,000 requests per minute allowed by the highest-tier subscription plan, so legitimate users hit their plan limit long before the Nginx one. This has been working well for me so far.

Cheers!


The (almost?) complete code is here:

package ratelimit // the package name is illustrative

import (
	"strconv"
	"time"

	"github.com/gofiber/fiber/v2"
	"github.com/redis/go-redis/v9"
)

var luaScript = `` // the Lua script shown above goes here

type RateLimiter struct {
	redisClient     *redis.Client
	rateLimitScript *redis.Script
}

// NewRateLimiter creates the script handle once so every request reuses it.
func NewRateLimiter(
	rdb *redis.Client,
) *RateLimiter {
	return &RateLimiter{
		redisClient:     rdb,
		rateLimitScript: redis.NewScript(luaScript),
	}
}

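func (r *RateLimiter) ApplyRateLimit(c *fiber.Ctx) error {
	// c.Locals returns interface{}, so assert it back to the type the auth
	// middleware stored (assumed here to be *Subscription)
	sub, ok := c.Locals("subscription").(*Subscription)
	if !ok || sub == nil {
		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
			"error_code": "unexpected_api_server_error",
		})
	}

	userDataKey := sub.UserId + "-v1"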

	now := time.Now().UTC()
	duration := time.Minute
	expiresAt := now.Add(duration)

	key := []string{userDataKey}
	values := []interface{}{sub.RequestsPerMinute, duration.Milliseconds(), now.UnixMilli()}

	requestsInWindow, err := r.rateLimitScript.Run(c.Context(), r.redisClient, key, values).Int64()

	if err != nil {
		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
			"error_code": "unexpected_api_server_error",
		})
	}

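	allowedRate := sub.RequestsPerMinute

	// The script returns the count before this request was logged, so an
	// admitted request leaves one slot fewer; clamp at zero when rejected
	remaining := allowedRate - requestsInWindow - 1
	if remaining < 0 {
		remaining = 0
	}

	c.Set("RateLimit-Limit", strconv.FormatInt(allowedRate, 10))
	c.Set("RateLimit-Remaining", strconv.FormatInt(remaining, 10))
	c.Set("RateLimit-Reset", strconv.FormatInt(expiresAt.Unix(), 10))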

	if requestsInWindow >= allowedRate {
		return c.Status(fiber.StatusTooManyRequests).JSON(fiber.Map{
			"error_code": "rate_limit_exceeded",
		})
	}

	return c.Next()
}