Skip to content

Commit d49b6b6

Browse files
RedisConcurrencyRateLimiter now supports permitCount parameter
1 parent f9b13f7 commit d49b6b6

File tree

3 files changed

+114
-25
lines changed

3 files changed

+114
-25
lines changed

src/RedisRateLimiting/Concurrency/RedisConcurrencyManager.cs

+49-17
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using StackExchange.Redis;
22
using System;
3+
using System.Collections.Generic;
34
using System.Threading.RateLimiting;
45
using System.Threading.Tasks;
56

@@ -18,6 +19,7 @@ internal class RedisConcurrencyManager
1819
local queue_limit = tonumber(@queue_limit)
1920
local try_enqueue = tonumber(@try_enqueue)
2021
local timestamp = tonumber(@current_time)
22+
local requested = tonumber(@permit_count)
2123
-- max seconds it takes to complete a request
2224
local ttl = 60
2325
@@ -29,10 +31,19 @@ internal class RedisConcurrencyManager
2931
end
3032
3133
local count = redis.call(""zcard"", @rate_limit_key)
32-
local allowed = count < limit
34+
local allowed = count + requested <= limit
3335
local queued = false
3436
local queue_count = 0
3537
38+
local addparams = {}
39+
local remparams = {}
40+
for i=1,requested do
41+
local index = i*2
42+
addparams[index-1]=timestamp
43+
addparams[index]=@unique_id..':'..tostring(i)
44+
remparams[i]=addparams[index]
45+
end
46+
3647
if allowed
3748
then
3849
@@ -45,23 +56,23 @@ internal class RedisConcurrencyManager
4556
if queue_count == 0 or try_enqueue == 0
4657
then
4758
48-
redis.call(""zadd"", @rate_limit_key, timestamp, @unique_id)
59+
redis.call(""zadd"", @rate_limit_key, unpack(addparams))
4960
5061
if queue_limit > 0
5162
then
5263
-- remove from pending queue
53-
redis.call(""zrem"", @queue_key, @unique_id)
64+
redis.call(""zrem"", @queue_key, unpack(remparams))
5465
end
5566
5667
else
5768
-- queue the current request next in line if we have any requests in the pending queue
5869
allowed = false
5970
60-
queued = queue_count + count < limit + queue_limit
71+
queued = queue_count + count + requested <= limit + queue_limit
6172
6273
if queued
6374
then
64-
redis.call(""zadd"", @queue_key, timestamp, @unique_id)
75+
redis.call(""zadd"", @queue_key, unpack(addparams))
6576
end
6677
6778
end
@@ -72,23 +83,23 @@ internal class RedisConcurrencyManager
7283
then
7384
7485
queue_count = redis.call(""zcard"", @queue_key)
75-
queued = queue_count < queue_limit
86+
queued = queue_count + requested <= queue_limit
7687
7788
if queued
7889
then
79-
redis.call(""zadd"", @queue_key, timestamp, @unique_id)
90+
redis.call(""zadd"", @queue_key, unpack(addparams))
8091
end
8192
8293
end
8394
end
8495
8596
if allowed
8697
then
87-
redis.call(""hincrby"", @stats_key, 'total_successful', 1)
98+
redis.call(""hincrby"", @stats_key, 'total_successful', requested)
8899
else
89100
if queued == false and try_enqueue == 1
90101
then
91-
redis.call(""hincrby"", @stats_key, 'total_failed', 1)
102+
redis.call(""hincrby"", @stats_key, 'total_failed', requested)
92103
end
93104
end
94105
@@ -114,7 +125,7 @@ public RedisConcurrencyManager(
114125
StatsRateLimitKey = new RedisKey($"rl:{{{partitionKey}}}:stats");
115126
}
116127

117-
internal async Task<RedisConcurrencyResponse> TryAcquireLeaseAsync(string requestId, bool tryEnqueue = false)
128+
internal async Task<RedisConcurrencyResponse> TryAcquireLeaseAsync(string requestId, int permitCount, bool tryEnqueue = false)
118129
{
119130
var nowUnixTimeSeconds = DateTimeOffset.UtcNow.ToUnixTimeSeconds();
120131

@@ -132,6 +143,7 @@ internal async Task<RedisConcurrencyResponse> TryAcquireLeaseAsync(string reques
132143
stats_key = StatsRateLimitKey,
133144
current_time = nowUnixTimeSeconds,
134145
unique_id = requestId,
146+
permit_count = permitCount
135147
});
136148

137149
var result = new RedisConcurrencyResponse();
@@ -147,7 +159,7 @@ internal async Task<RedisConcurrencyResponse> TryAcquireLeaseAsync(string reques
147159
return result;
148160
}
149161

150-
internal RedisConcurrencyResponse TryAcquireLease(string requestId, bool tryEnqueue = false)
162+
internal RedisConcurrencyResponse TryAcquireLease(string requestId, int permitCount, bool tryEnqueue = false)
151163
{
152164
var nowUnixTimeSeconds = DateTimeOffset.UtcNow.ToUnixTimeSeconds();
153165

@@ -165,6 +177,7 @@ internal RedisConcurrencyResponse TryAcquireLease(string requestId, bool tryEnqu
165177
stats_key = StatsRateLimitKey,
166178
current_time = nowUnixTimeSeconds,
167179
unique_id = requestId,
180+
permit_count = permitCount
168181
});
169182

170183
var result = new RedisConcurrencyResponse();
@@ -180,22 +193,41 @@ internal RedisConcurrencyResponse TryAcquireLease(string requestId, bool tryEnqu
180193
return result;
181194
}
182195

183-
internal void ReleaseLease(string requestId)
196+
internal void ReleaseLease(string requestId, int permitCount)
184197
{
185198
var database = _connectionMultiplexer.GetDatabase();
186-
database.SortedSetRemove(RateLimitKey, requestId);
199+
200+
for (var i = 1; i <= permitCount; i++)
201+
{
202+
database.SortedSetRemove(RateLimitKey, $"{requestId}:{i}");
203+
}
187204
}
188205

189-
internal async Task ReleaseLeaseAsync(string requestId)
206+
internal Task ReleaseLeaseAsync(string requestId, int permitCount)
190207
{
191208
var database = _connectionMultiplexer.GetDatabase();
192-
await database.SortedSetRemoveAsync(RateLimitKey, requestId);
209+
var tasks = new List<Task>(permitCount);
210+
211+
for (var i = 1; i <= permitCount; i++)
212+
{
213+
tasks.Add(database.SortedSetRemoveAsync(RateLimitKey, $"{requestId}:{i}"));
214+
}
215+
216+
return Task.WhenAll(tasks);
193217
}
194218

195-
internal async Task ReleaseQueueLeaseAsync(string requestId)
219+
internal Task ReleaseQueueLeaseAsync(string requestId, int permitCount)
196220
{
197221
var database = _connectionMultiplexer.GetDatabase();
198-
await database.SortedSetRemoveAsync(QueueRateLimitKey, requestId);
222+
223+
var tasks = new List<Task>(permitCount);
224+
225+
for (var i = 1; i <= permitCount; i++)
226+
{
227+
tasks.Add(database.SortedSetRemoveAsync(QueueRateLimitKey, $"{requestId}:{i}"));
228+
}
229+
230+
return Task.WhenAll(tasks);
199231
}
200232

201233
internal RateLimiterStatistics? GetStatistics()

src/RedisRateLimiting/Concurrency/RedisConcurrencyRateLimiter.cs

+12-8
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ protected override ValueTask<RateLimitLease> AcquireAsyncCore(int permitCount, C
7171
throw new ArgumentOutOfRangeException(nameof(permitCount), permitCount, string.Format("{0} permit(s) exceeds the permit limit of {1}.", permitCount, _options.PermitLimit));
7272
}
7373

74-
return AcquireAsyncCoreInternal(cancellationToken);
74+
return AcquireAsyncCoreInternal(permitCount, cancellationToken);
7575
}
7676

7777
protected override RateLimitLease AttemptAcquireCore(int permitCount)
@@ -85,9 +85,10 @@ protected override RateLimitLease AttemptAcquireCore(int permitCount)
8585
{
8686
Limit = _options.PermitLimit,
8787
RequestId = Guid.NewGuid().ToString(),
88+
PermitCount = permitCount
8889
};
8990

90-
var response = _redisManager.TryAcquireLease(leaseContext.RequestId);
91+
var response = _redisManager.TryAcquireLease(leaseContext.RequestId, permitCount);
9192

9293
leaseContext.Count = response.Count;
9394

@@ -99,15 +100,16 @@ protected override RateLimitLease AttemptAcquireCore(int permitCount)
99100
return new ConcurrencyLease(isAcquired: false, this, leaseContext);
100101
}
101102

102-
private async ValueTask<RateLimitLease> AcquireAsyncCoreInternal(CancellationToken cancellationToken)
103+
private async ValueTask<RateLimitLease> AcquireAsyncCoreInternal(int permitCount, CancellationToken cancellationToken)
103104
{
104105
var leaseContext = new ConcurencyLeaseContext
105106
{
106107
Limit = _options.PermitLimit,
107108
RequestId = Guid.NewGuid().ToString(),
109+
PermitCount = permitCount
108110
};
109111

110-
var response = await _redisManager.TryAcquireLeaseAsync(leaseContext.RequestId, tryEnqueue: true);
112+
var response = await _redisManager.TryAcquireLeaseAsync(leaseContext.RequestId, permitCount, tryEnqueue: true);
111113

112114
leaseContext.Count = response.Count;
113115

@@ -148,7 +150,7 @@ private void Release(ConcurencyLeaseContext leaseContext)
148150
{
149151
if (leaseContext.RequestId is null) return;
150152

151-
_redisManager.ReleaseLease(leaseContext.RequestId);
153+
_redisManager.ReleaseLease(leaseContext.RequestId, leaseContext.PermitCount);
152154
}
153155

154156
private async Task StartDequeueTimerAsync(PeriodicTimer periodicTimer)
@@ -170,7 +172,7 @@ private async Task TryDequeueRequestsAsync()
170172
try
171173
{
172174
// The request was canceled while in the pending queue
173-
await _redisManager.ReleaseQueueLeaseAsync(request.LeaseContext!.RequestId!);
175+
await _redisManager.ReleaseQueueLeaseAsync(request.LeaseContext!.RequestId!, request.LeaseContext!.PermitCount);
174176
}
175177
finally
176178
{
@@ -182,7 +184,7 @@ private async Task TryDequeueRequestsAsync()
182184
continue;
183185
}
184186

185-
var response = await _redisManager.TryAcquireLeaseAsync(request.LeaseContext!.RequestId!);
187+
var response = await _redisManager.TryAcquireLeaseAsync(request.LeaseContext!.RequestId!, request.LeaseContext!.PermitCount);
186188

187189
request.LeaseContext.Count = response.Count;
188190

@@ -195,7 +197,7 @@ private async Task TryDequeueRequestsAsync()
195197
if (request.TaskCompletionSource?.TrySetResult(pendingLease) == false)
196198
{
197199
// The request was canceled after we acquired the lease
198-
await _redisManager.ReleaseLeaseAsync(request.LeaseContext!.RequestId!);
200+
await _redisManager.ReleaseLeaseAsync(request.LeaseContext!.RequestId!, request.LeaseContext!.PermitCount);
199201
}
200202
}
201203
finally
@@ -257,6 +259,8 @@ private sealed class ConcurencyLeaseContext
257259
public long Count { get; set; }
258260

259261
public long Limit { get; set; }
262+
263+
public int PermitCount { get; set; }
260264
}
261265

262266
private sealed class ConcurrencyLease : RateLimitLease

test/RedisRateLimiting.Tests/UnitTests/ConcurrencyUnitTests.cs

+53
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,59 @@ public void CanAcquireResource()
120120
lease.Dispose();
121121
}
122122

123+
[Fact]
124+
public async Task SupportsPermitCountFlag_NoQueue()
125+
{
126+
using var limiter = new RedisConcurrencyRateLimiter<string>(
127+
"Test_SupportsPermitCountFlag_NoQueue_Concurrency",
128+
new RedisConcurrencyRateLimiterOptions
129+
{
130+
PermitLimit = 2,
131+
QueueLimit = 0,
132+
ConnectionMultiplexerFactory = Fixture.ConnectionMultiplexerFactory,
133+
});
134+
135+
var lease1 = await limiter.AcquireAsync(permitCount: 2);
136+
var lease2 = await limiter.AcquireAsync(permitCount: 1);
137+
138+
Assert.True(lease1.IsAcquired);
139+
Assert.False(lease2.IsAcquired);
140+
141+
lease1.Dispose();
142+
lease2.Dispose();
143+
}
144+
145+
[Fact]
146+
public async Task SupportsPermitCountFlag_WithQueue()
147+
{
148+
using var limiter = new RedisConcurrencyRateLimiter<string>(
149+
"Test_SupportsPermitCountFlag_WithQueue_Concurrency",
150+
new RedisConcurrencyRateLimiterOptions
151+
{
152+
PermitLimit = 2,
153+
QueueLimit = 2,
154+
ConnectionMultiplexerFactory = Fixture.ConnectionMultiplexerFactory,
155+
});
156+
157+
var lease1 = await limiter.AcquireAsync(permitCount: 2);
158+
Assert.True(lease1.IsAcquired);
159+
160+
var wait2 = limiter.AcquireAsync(permitCount: 2);
161+
var wait3 = limiter.AcquireAsync(permitCount: 1);
162+
163+
await Task.Delay(1000);
164+
lease1.Dispose();
165+
166+
var lease2 = await wait2;
167+
Assert.True(lease2.IsAcquired);
168+
169+
var lease3 = await wait3;
170+
Assert.False(lease3.IsAcquired);
171+
172+
lease2.Dispose();
173+
lease3.Dispose();
174+
}
175+
123176
[Fact]
124177
public async Task CanAcquireResourceAsyncQueuesAndGrabsOldest()
125178
{

0 commit comments

Comments
 (0)