Skip to content

permitCount values larger than 1 are now supported #49

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 49 additions & 17 deletions src/RedisRateLimiting/Concurrency/RedisConcurrencyManager.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using StackExchange.Redis;
using System;
using System.Collections.Generic;
using System.Threading.RateLimiting;
using System.Threading.Tasks;

Expand All @@ -18,6 +19,7 @@ internal class RedisConcurrencyManager
local queue_limit = tonumber(@queue_limit)
local try_enqueue = tonumber(@try_enqueue)
local timestamp = tonumber(@current_time)
local requested = tonumber(@permit_count)
-- max seconds it takes to complete a request
local ttl = 60

Expand All @@ -29,10 +31,19 @@ internal class RedisConcurrencyManager
end

local count = redis.call(""zcard"", @rate_limit_key)
local allowed = count < limit
local allowed = count + requested <= limit
local queued = false
local queue_count = 0

local addparams = {}
local remparams = {}
for i=1,requested do
local index = i*2
addparams[index-1]=timestamp
addparams[index]=@unique_id..':'..tostring(i)
remparams[i]=addparams[index]
end

if allowed
then

Expand All @@ -45,23 +56,23 @@ internal class RedisConcurrencyManager
if queue_count == 0 or try_enqueue == 0
then

redis.call(""zadd"", @rate_limit_key, timestamp, @unique_id)
redis.call(""zadd"", @rate_limit_key, unpack(addparams))

if queue_limit > 0
then
-- remove from pending queue
redis.call(""zrem"", @queue_key, @unique_id)
redis.call(""zrem"", @queue_key, unpack(remparams))
end

else
-- queue the current request next in line if we have any requests in the pending queue
allowed = false

queued = queue_count + count < limit + queue_limit
queued = queue_count + count + requested <= limit + queue_limit

if queued
then
redis.call(""zadd"", @queue_key, timestamp, @unique_id)
redis.call(""zadd"", @queue_key, unpack(addparams))
end

end
Expand All @@ -72,23 +83,23 @@ internal class RedisConcurrencyManager
then

queue_count = redis.call(""zcard"", @queue_key)
queued = queue_count < queue_limit
queued = queue_count + requested <= queue_limit

if queued
then
redis.call(""zadd"", @queue_key, timestamp, @unique_id)
redis.call(""zadd"", @queue_key, unpack(addparams))
end

end
end

if allowed
then
redis.call(""hincrby"", @stats_key, 'total_successful', 1)
redis.call(""hincrby"", @stats_key, 'total_successful', requested)
else
if queued == false and try_enqueue == 1
then
redis.call(""hincrby"", @stats_key, 'total_failed', 1)
redis.call(""hincrby"", @stats_key, 'total_failed', requested)
end
end

Expand All @@ -114,7 +125,7 @@ public RedisConcurrencyManager(
StatsRateLimitKey = new RedisKey($"rl:{{{partitionKey}}}:stats");
}

internal async Task<RedisConcurrencyResponse> TryAcquireLeaseAsync(string requestId, bool tryEnqueue = false)
internal async Task<RedisConcurrencyResponse> TryAcquireLeaseAsync(string requestId, int permitCount, bool tryEnqueue = false)
{
var nowUnixTimeSeconds = DateTimeOffset.UtcNow.ToUnixTimeSeconds();

Expand All @@ -129,6 +140,7 @@ internal async Task<RedisConcurrencyResponse> TryAcquireLeaseAsync(string reques
stats_key = StatsRateLimitKey,
permit_limit = (RedisValue)_options.PermitLimit,
try_enqueue = (RedisValue)(tryEnqueue ? 1 : 0),
permit_count = (RedisValue)permitCount,
queue_limit = (RedisValue)_options.QueueLimit,
current_time = (RedisValue)nowUnixTimeSeconds,
unique_id = (RedisValue)requestId,
Expand All @@ -147,7 +159,7 @@ internal async Task<RedisConcurrencyResponse> TryAcquireLeaseAsync(string reques
return result;
}

internal RedisConcurrencyResponse TryAcquireLease(string requestId, bool tryEnqueue = false)
internal RedisConcurrencyResponse TryAcquireLease(string requestId, int permitCount, bool tryEnqueue = false)
{
var nowUnixTimeSeconds = DateTimeOffset.UtcNow.ToUnixTimeSeconds();

Expand All @@ -162,6 +174,7 @@ internal RedisConcurrencyResponse TryAcquireLease(string requestId, bool tryEnqu
stats_key = StatsRateLimitKey,
permit_limit = (RedisValue)_options.PermitLimit,
try_enqueue = (RedisValue)(tryEnqueue ? 1 : 0),
permit_count = (RedisValue)permitCount,
queue_limit = (RedisValue)_options.QueueLimit,
current_time = (RedisValue)nowUnixTimeSeconds,
unique_id = (RedisValue)requestId,
Expand All @@ -180,22 +193,41 @@ internal RedisConcurrencyResponse TryAcquireLease(string requestId, bool tryEnqu
return result;
}

internal void ReleaseLease(string requestId)
internal void ReleaseLease(string requestId, int permitCount)
{
var database = _connectionMultiplexer.GetDatabase();
database.SortedSetRemove(RateLimitKey, requestId);

for (var i = 1; i <= permitCount; i++)
{
database.SortedSetRemove(RateLimitKey, $"{requestId}:{i}");
}
}

internal async Task ReleaseLeaseAsync(string requestId)
internal Task ReleaseLeaseAsync(string requestId, int permitCount)
{
var database = _connectionMultiplexer.GetDatabase();
await database.SortedSetRemoveAsync(RateLimitKey, requestId);
var tasks = new List<Task>(permitCount);

for (var i = 1; i <= permitCount; i++)
{
tasks.Add(database.SortedSetRemoveAsync(RateLimitKey, $"{requestId}:{i}"));
}

return Task.WhenAll(tasks);
}

internal async Task ReleaseQueueLeaseAsync(string requestId)
internal Task ReleaseQueueLeaseAsync(string requestId, int permitCount)
{
var database = _connectionMultiplexer.GetDatabase();
await database.SortedSetRemoveAsync(QueueRateLimitKey, requestId);

var tasks = new List<Task>(permitCount);

for (var i = 1; i <= permitCount; i++)
{
tasks.Add(database.SortedSetRemoveAsync(QueueRateLimitKey, $"{requestId}:{i}"));
}

return Task.WhenAll(tasks);
}

internal RateLimiterStatistics? GetStatistics()
Expand Down
20 changes: 12 additions & 8 deletions src/RedisRateLimiting/Concurrency/RedisConcurrencyRateLimiter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ protected override ValueTask<RateLimitLease> AcquireAsyncCore(int permitCount, C
throw new ArgumentOutOfRangeException(nameof(permitCount), permitCount, string.Format("{0} permit(s) exceeds the permit limit of {1}.", permitCount, _options.PermitLimit));
}

return AcquireAsyncCoreInternal(cancellationToken);
return AcquireAsyncCoreInternal(permitCount, cancellationToken);
}

protected override RateLimitLease AttemptAcquireCore(int permitCount)
Expand All @@ -85,9 +85,10 @@ protected override RateLimitLease AttemptAcquireCore(int permitCount)
{
Limit = _options.PermitLimit,
RequestId = Guid.NewGuid().ToString(),
PermitCount = permitCount
};

var response = _redisManager.TryAcquireLease(leaseContext.RequestId);
var response = _redisManager.TryAcquireLease(leaseContext.RequestId, permitCount);

leaseContext.Count = response.Count;

Expand All @@ -99,15 +100,16 @@ protected override RateLimitLease AttemptAcquireCore(int permitCount)
return new ConcurrencyLease(isAcquired: false, this, leaseContext);
}

private async ValueTask<RateLimitLease> AcquireAsyncCoreInternal(CancellationToken cancellationToken)
private async ValueTask<RateLimitLease> AcquireAsyncCoreInternal(int permitCount, CancellationToken cancellationToken)
{
var leaseContext = new ConcurencyLeaseContext
{
Limit = _options.PermitLimit,
RequestId = Guid.NewGuid().ToString(),
PermitCount = permitCount
};

var response = await _redisManager.TryAcquireLeaseAsync(leaseContext.RequestId, tryEnqueue: true);
var response = await _redisManager.TryAcquireLeaseAsync(leaseContext.RequestId, permitCount, tryEnqueue: true);

leaseContext.Count = response.Count;

Expand Down Expand Up @@ -148,7 +150,7 @@ private void Release(ConcurencyLeaseContext leaseContext)
{
if (leaseContext.RequestId is null) return;

_redisManager.ReleaseLease(leaseContext.RequestId);
_redisManager.ReleaseLease(leaseContext.RequestId, leaseContext.PermitCount);
}

private async Task StartDequeueTimerAsync(PeriodicTimer periodicTimer)
Expand All @@ -170,7 +172,7 @@ private async Task TryDequeueRequestsAsync()
try
{
// The request was canceled while in the pending queue
await _redisManager.ReleaseQueueLeaseAsync(request.LeaseContext!.RequestId!);
await _redisManager.ReleaseQueueLeaseAsync(request.LeaseContext!.RequestId!, request.LeaseContext!.PermitCount);
}
finally
{
Expand All @@ -182,7 +184,7 @@ private async Task TryDequeueRequestsAsync()
continue;
}

var response = await _redisManager.TryAcquireLeaseAsync(request.LeaseContext!.RequestId!);
var response = await _redisManager.TryAcquireLeaseAsync(request.LeaseContext!.RequestId!, request.LeaseContext!.PermitCount);

request.LeaseContext.Count = response.Count;

Expand All @@ -195,7 +197,7 @@ private async Task TryDequeueRequestsAsync()
if (request.TaskCompletionSource?.TrySetResult(pendingLease) == false)
{
// The request was canceled after we acquired the lease
await _redisManager.ReleaseLeaseAsync(request.LeaseContext!.RequestId!);
await _redisManager.ReleaseLeaseAsync(request.LeaseContext!.RequestId!, request.LeaseContext!.PermitCount);
}
}
finally
Expand Down Expand Up @@ -256,6 +258,8 @@ private sealed class ConcurencyLeaseContext
public long Count { get; set; }

public long Limit { get; set; }

public int PermitCount { get; set; }
}

private sealed class ConcurrencyLease : RateLimitLease
Expand Down
30 changes: 22 additions & 8 deletions src/RedisRateLimiting/FixedWindow/RedisFixedWindowManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ internal class RedisFixedWindowManager

private static readonly LuaScript Script = LuaScript.Prepare(
@"local expires_at = tonumber(redis.call(""get"", @expires_at_key))
local current = tonumber(redis.call(""get"", @rate_limit_key))
local requested = tonumber(@increment_amount)
local limit = tonumber(@permit_limit)

if not expires_at or expires_at < tonumber(@current_time) then
-- this is either a brand new window,
Expand All @@ -27,13 +30,19 @@ internal class RedisFixedWindowManager
redis.call(""expireat"", @expires_at_key, @next_expires_at + 1)
-- since the database was updated, return the new value
expires_at = @next_expires_at
current = 0
end

-- now that the window either already exists or it was freshly initialized,
-- increment the counter(`incrby` returns a number)
local current = redis.call(""incrby"", @rate_limit_key, @increment_amount)
local allowed = current + requested <= limit

return { current, expires_at }");
if allowed
then
-- now that the window either already exists or it was freshly initialized,
-- increment the counter(`incrby` returns a number)
current = redis.call(""incrby"", @rate_limit_key, @increment_amount)
end

return { current, expires_at, allowed }");

public RedisFixedWindowManager(
string partitionKey,
Expand All @@ -46,7 +55,7 @@ public RedisFixedWindowManager(
RateLimitExpireKey = new RedisKey($"rl:{{{partitionKey}}}:exp");
}

internal async Task<RedisFixedWindowResponse> TryAcquireLeaseAsync()
internal async Task<RedisFixedWindowResponse> TryAcquireLeaseAsync(int permitCount)
{
var now = DateTimeOffset.UtcNow;
var nowUnixTimeSeconds = now.ToUnixTimeSeconds();
Expand All @@ -59,9 +68,10 @@ internal async Task<RedisFixedWindowResponse> TryAcquireLeaseAsync()
{
rate_limit_key = RateLimitKey,
expires_at_key = RateLimitExpireKey,
permit_limit = _options.PermitLimit,
next_expires_at = (RedisValue)now.Add(_options.Window).ToUnixTimeSeconds(),
current_time = (RedisValue)nowUnixTimeSeconds,
increment_amount = (RedisValue)1D,
increment_amount = (RedisValue)permitCount,
});

var result = new RedisFixedWindowResponse();
Expand All @@ -70,13 +80,14 @@ internal async Task<RedisFixedWindowResponse> TryAcquireLeaseAsync()
{
result.Count = (long)response[0];
result.ExpiresAt = (long)response[1];
result.Allowed = (bool)response[2];
result.RetryAfter = TimeSpan.FromSeconds(result.ExpiresAt - nowUnixTimeSeconds);
}

return result;
}

internal RedisFixedWindowResponse TryAcquireLease()
internal RedisFixedWindowResponse TryAcquireLease(int permitCount)
{
var now = DateTimeOffset.UtcNow;
var nowUnixTimeSeconds = now.ToUnixTimeSeconds();
Expand All @@ -89,9 +100,10 @@ internal RedisFixedWindowResponse TryAcquireLease()
{
rate_limit_key = RateLimitKey,
expires_at_key = RateLimitExpireKey,
permit_limit = _options.PermitLimit,
next_expires_at = (RedisValue)now.Add(_options.Window).ToUnixTimeSeconds(),
current_time = (RedisValue)nowUnixTimeSeconds,
increment_amount = (RedisValue)1D,
increment_amount = (RedisValue)permitCount,
});

var result = new RedisFixedWindowResponse();
Expand All @@ -100,6 +112,7 @@ internal RedisFixedWindowResponse TryAcquireLease()
{
result.Count = (long)response[0];
result.ExpiresAt = (long)response[1];
result.Allowed = (bool)response[2];
result.RetryAfter = TimeSpan.FromSeconds(result.ExpiresAt - nowUnixTimeSeconds);
}

Expand All @@ -112,5 +125,6 @@ internal class RedisFixedWindowResponse
internal long ExpiresAt { get; set; }
internal TimeSpan RetryAfter { get; set; }
internal long Count { get; set; }
internal bool Allowed { get; set; }
}
}
Loading