From 7b4bc044e4378079b5b06a437fab8f55c6adfb75 Mon Sep 17 00:00:00 2001 From: Manuel Spezzani Date: Fri, 7 Apr 2023 11:35:52 +0200 Subject: [PATCH 1/4] RedisSlidingWindowRateLimiter now supports permitCount parameter --- .../RedisSlidingWindowManager.cs | 18 ++++++++++++++---- .../RedisSlidingWindowRateLimiter.cs | 8 ++++---- .../UnitTests/SlidingWindowUnitTests.cs | 19 +++++++++++++++++++ 3 files changed, 37 insertions(+), 8 deletions(-) diff --git a/src/RedisRateLimiting/SlidingWindow/RedisSlidingWindowManager.cs b/src/RedisRateLimiting/SlidingWindow/RedisSlidingWindowManager.cs index c6dcf57..de52955 100644 --- a/src/RedisRateLimiting/SlidingWindow/RedisSlidingWindowManager.cs +++ b/src/RedisRateLimiting/SlidingWindow/RedisSlidingWindowManager.cs @@ -16,16 +16,24 @@ internal class RedisSlidingWindowManager @"local limit = tonumber(@permit_limit) local timestamp = tonumber(@current_time) local window = tonumber(@window) + local requested = tonumber(@permit_count) + + local zaddparams = {} + for i=1,requested do + local index = i*2 + zaddparams[index-1]=timestamp + zaddparams[index]=@unique_id..':'..tostring(i) + end -- remove all requests outside current window redis.call(""zremrangebyscore"", @rate_limit_key, '-inf', timestamp - window) local count = redis.call(""zcard"", @rate_limit_key) - local allowed = count < limit + local allowed = count + requested <= limit if allowed then - redis.call(""zadd"", @rate_limit_key, timestamp, @unique_id) + redis.call(""zadd"", @rate_limit_key, unpack(zaddparams)) end redis.call(""expireat"", @rate_limit_key, timestamp + window + 1) @@ -57,7 +65,7 @@ public RedisSlidingWindowManager( StatsRateLimitKey = new RedisKey($"rl:{{{partitionKey}}}:stats"); } - internal async Task TryAcquireLeaseAsync(string requestId) + internal async Task TryAcquireLeaseAsync(string requestId, int permitCount) { var now = DateTimeOffset.UtcNow; var nowUnixTimeSeconds = now.ToUnixTimeSeconds(); @@ -74,6 +82,7 @@ internal async Task TryAcquireLeaseAsync(string requ stats_key = StatsRateLimitKey, current_time = nowUnixTimeSeconds, unique_id = requestId, + permit_count = permitCount }); var result = new RedisSlidingWindowResponse(); @@ -87,7 +96,7 @@ internal async Task TryAcquireLeaseAsync(string requ return result; } - internal RedisSlidingWindowResponse TryAcquireLease(string requestId) + internal RedisSlidingWindowResponse TryAcquireLease(string requestId, int permitCount) { var now = DateTimeOffset.UtcNow; var nowUnixTimeSeconds = now.ToUnixTimeSeconds(); @@ -104,6 +113,7 @@ internal RedisSlidingWindowResponse TryAcquireLease(string requestId) stats_key = StatsRateLimitKey, current_time = nowUnixTimeSeconds, unique_id = requestId, + permit_count = permitCount }); var result = new RedisSlidingWindowResponse(); diff --git a/src/RedisRateLimiting/SlidingWindow/RedisSlidingWindowRateLimiter.cs b/src/RedisRateLimiting/SlidingWindow/RedisSlidingWindowRateLimiter.cs index 5d3eae3..df1dc18 100644 --- a/src/RedisRateLimiting/SlidingWindow/RedisSlidingWindowRateLimiter.cs +++ b/src/RedisRateLimiting/SlidingWindow/RedisSlidingWindowRateLimiter.cs @@ -55,7 +55,7 @@ protected override ValueTask AcquireAsyncCore(int permitCount, C throw new ArgumentOutOfRangeException(nameof(permitCount), permitCount, string.Format("{0} permit(s) exceeds the permit limit of {1}.", permitCount, _options.PermitLimit)); } - return AcquireAsyncCoreInternal(); + return AcquireAsyncCoreInternal(permitCount); } protected override RateLimitLease AttemptAcquireCore(int permitCount) @@ -72,7 +72,7 @@ protected override RateLimitLease AttemptAcquireCore(int permitCount) RequestId = Guid.NewGuid().ToString(), }; - var response = _redisManager.TryAcquireLease(leaseContext.RequestId); + var response = _redisManager.TryAcquireLease(leaseContext.RequestId, permitCount); leaseContext.Count = response.Count; leaseContext.Allowed = response.Allowed; @@ -85,7 +85,7 @@ protected override RateLimitLease AttemptAcquireCore(int permitCount) return new SlidingWindowLease(isAcquired: false, leaseContext); } - private async ValueTask AcquireAsyncCoreInternal() + private async ValueTask AcquireAsyncCoreInternal(int permitCount) { var leaseContext = new SlidingWindowLeaseContext { @@ -94,7 +94,7 @@ private async ValueTask AcquireAsyncCoreInternal() RequestId = Guid.NewGuid().ToString(), }; - var response = await _redisManager.TryAcquireLeaseAsync(leaseContext.RequestId); + var response = await _redisManager.TryAcquireLeaseAsync(leaseContext.RequestId, permitCount); leaseContext.Count = response.Count; leaseContext.Allowed = response.Allowed; diff --git a/test/RedisRateLimiting.Tests/UnitTests/SlidingWindowUnitTests.cs b/test/RedisRateLimiting.Tests/UnitTests/SlidingWindowUnitTests.cs index b9858e8..4de556d 100644 --- a/test/RedisRateLimiting.Tests/UnitTests/SlidingWindowUnitTests.cs +++ b/test/RedisRateLimiting.Tests/UnitTests/SlidingWindowUnitTests.cs @@ -106,5 +106,24 @@ public async Task CanAcquireAsyncResource() stats = limiter.GetStatistics()!; Assert.Equal(0, stats.CurrentAvailablePermits); } + + [Fact] + public async Task SupportsPermitCountFlag() + { + using var limiter = new RedisSlidingWindowRateLimiter( + "Test_SupportsPermitCountFlag_SW", + new RedisSlidingWindowRateLimiterOptions + { + PermitLimit = 3, + Window = TimeSpan.FromMinutes(1), + ConnectionMultiplexerFactory = Fixture.ConnectionMultiplexerFactory, + }); + + using var lease = await limiter.AcquireAsync(permitCount: 3); + Assert.True(lease.IsAcquired); + + using var lease2 = await limiter.AcquireAsync(permitCount: 1); + Assert.False(lease2.IsAcquired); + } } } From 0b50d5ff6ff41cffe9320ca3f79141559257a513 Mon Sep 17 00:00:00 2001 From: Manuel Spezzani Date: Fri, 7 Apr 2023 11:34:54 +0200 Subject: [PATCH 2/4] RedisFixedWindowRateLimiter now supports permitCount parameter (cherry picked from commit bb6baed024a4595a639085db09df954c90b064bf) --- .../FixedWindow/RedisFixedWindowManager.cs | 30 ++++++++++++++----- .../RedisFixedWindowRateLimiter.cs | 24 +++++---------- .../UnitTests/FixedWindowUnitTests.cs | 22 ++++++++++++++ 3 files changed, 51 insertions(+), 25 deletions(-) diff --git a/src/RedisRateLimiting/FixedWindow/RedisFixedWindowManager.cs b/src/RedisRateLimiting/FixedWindow/RedisFixedWindowManager.cs index eb6c6be..35fa476 100644 --- a/src/RedisRateLimiting/FixedWindow/RedisFixedWindowManager.cs +++ b/src/RedisRateLimiting/FixedWindow/RedisFixedWindowManager.cs @@ -13,6 +13,9 @@ internal class RedisFixedWindowManager private static readonly LuaScript _redisScript = LuaScript.Prepare( @"local expires_at = tonumber(redis.call(""get"", @expires_at_key)) + local current = tonumber(redis.call(""get"", @rate_limit_key)) + local requested = tonumber(@increment_amount) + local limit = tonumber(@permit_limit) if not expires_at or expires_at < tonumber(@current_time) then -- this is either a brand new window, @@ -27,13 +30,19 @@ internal class RedisFixedWindowManager redis.call(""expireat"", @expires_at_key, @next_expires_at + 1) -- since the database was updated, return the new value expires_at = @next_expires_at + current = 0 end - -- now that the window either already exists or it was freshly initialized, - -- increment the counter(`incrby` returns a number) - local current = redis.call(""incrby"", @rate_limit_key, @increment_amount) + local allowed = current + requested <= limit - return { current, expires_at }"); + if allowed + then + -- now that the window either already exists or it was freshly initialized, + -- increment the counter(`incrby` returns a number) + current = redis.call(""incrby"", @rate_limit_key, @increment_amount) + end + + return { current, expires_at, allowed }"); public RedisFixedWindowManager( string partitionKey, @@ -46,7 +55,7 @@ public RedisFixedWindowManager( RateLimitExpireKey = new RedisKey($"rl:{{{partitionKey}}}:exp"); } - internal async Task TryAcquireLeaseAsync() + internal async Task TryAcquireLeaseAsync(int permitCount) { var now = DateTimeOffset.UtcNow; var nowUnixTimeSeconds = now.ToUnixTimeSeconds(); @@ -59,9 +68,10 @@ internal async Task TryAcquireLeaseAsync() { rate_limit_key = RateLimitKey, expires_at_key = RateLimitExpireKey, + permit_limit = _options.PermitLimit, next_expires_at = now.Add(_options.Window).ToUnixTimeSeconds(), current_time = nowUnixTimeSeconds, - increment_amount = 1D, + increment_amount = permitCount, }); var result = new RedisFixedWindowResponse(); @@ -70,13 +80,14 @@ internal async Task TryAcquireLeaseAsync() { result.Count = (long)response[0]; result.ExpiresAt = (long)response[1]; + result.Allowed = (bool)response[2]; result.RetryAfter = TimeSpan.FromSeconds(result.ExpiresAt - nowUnixTimeSeconds); } return result; } - internal RedisFixedWindowResponse TryAcquireLease() + internal RedisFixedWindowResponse TryAcquireLease(int permitCount) { var now = DateTimeOffset.UtcNow; var nowUnixTimeSeconds = now.ToUnixTimeSeconds(); @@ -89,9 +100,10 @@ internal RedisFixedWindowResponse TryAcquireLease() { rate_limit_key = RateLimitKey, expires_at_key = RateLimitExpireKey, + permit_limit = _options.PermitLimit, next_expires_at = now.Add(_options.Window).ToUnixTimeSeconds(), current_time = nowUnixTimeSeconds, - increment_amount = 1D, + increment_amount = permitCount, }); var result = new RedisFixedWindowResponse(); @@ -100,6 +112,7 @@ internal RedisFixedWindowResponse TryAcquireLease() { result.Count = (long)response[0]; result.ExpiresAt = (long)response[1]; + result.Allowed = (bool)response[2]; result.RetryAfter = TimeSpan.FromSeconds(result.ExpiresAt - nowUnixTimeSeconds); } @@ -112,5 +125,6 @@ internal class RedisFixedWindowResponse internal long ExpiresAt { get; set; } internal TimeSpan RetryAfter { get; set; } internal long Count { get; set; } + internal bool Allowed { get; set; } } } diff --git a/src/RedisRateLimiting/FixedWindow/RedisFixedWindowRateLimiter.cs b/src/RedisRateLimiting/FixedWindow/RedisFixedWindowRateLimiter.cs index 45734b6..0ba4b39 100644 --- a/src/RedisRateLimiting/FixedWindow/RedisFixedWindowRateLimiter.cs +++ b/src/RedisRateLimiting/FixedWindow/RedisFixedWindowRateLimiter.cs @@ -55,7 +55,7 @@ protected override ValueTask AcquireAsyncCore(int permitCount, C throw new ArgumentOutOfRangeException(nameof(permitCount), permitCount, string.Format("{0} permit(s) exceeds the permit limit of {1}.", permitCount, _options.PermitLimit)); } - return AcquireAsyncCoreInternal(); + return AcquireAsyncCoreInternal(permitCount); } protected override RateLimitLease AttemptAcquireCore(int permitCount) @@ -71,21 +71,16 @@ protected override RateLimitLease AttemptAcquireCore(int permitCount) Window = _options.Window, }; - var response = _redisManager.TryAcquireLease(); + var response = _redisManager.TryAcquireLease(permitCount); leaseContext.Count = response.Count; leaseContext.RetryAfter = response.RetryAfter; leaseContext.ExpiresAt = DateTimeOffset.FromUnixTimeSeconds(response.ExpiresAt); - - if (leaseContext.Count > _options.PermitLimit) - { - return new FixedWindowLease(isAcquired: false, leaseContext); - } - - return new FixedWindowLease(isAcquired: true, leaseContext); + + return new FixedWindowLease(isAcquired: response.Allowed, leaseContext); } - private async ValueTask AcquireAsyncCoreInternal() + private async ValueTask AcquireAsyncCoreInternal(int permitCount) { var leaseContext = new FixedWindowLeaseContext { @@ -93,17 +88,12 @@ private async ValueTask AcquireAsyncCoreInternal() Window = _options.Window, }; - var response = await _redisManager.TryAcquireLeaseAsync(); + var response = await _redisManager.TryAcquireLeaseAsync(permitCount); leaseContext.Count = response.Count; leaseContext.RetryAfter = response.RetryAfter; - if (leaseContext.Count > _options.PermitLimit) - { - return new FixedWindowLease(isAcquired: false, leaseContext); - } - - return new FixedWindowLease(isAcquired: true, leaseContext); + return new FixedWindowLease(isAcquired: response.Allowed, leaseContext); } private sealed class FixedWindowLeaseContext diff --git a/test/RedisRateLimiting.Tests/UnitTests/FixedWindowUnitTests.cs b/test/RedisRateLimiting.Tests/UnitTests/FixedWindowUnitTests.cs index 2c293ff..8c243f7 100644 --- a/test/RedisRateLimiting.Tests/UnitTests/FixedWindowUnitTests.cs +++ b/test/RedisRateLimiting.Tests/UnitTests/FixedWindowUnitTests.cs @@ -86,5 +86,27 @@ public async Task CanAcquireAsyncResource() using var lease2 = await limiter.AcquireAsync(); Assert.False(lease2.IsAcquired); } + + [Fact] + public async Task SupportsPermitCountFlag() + { + using var limiter = new RedisFixedWindowRateLimiter( + "Test_SupportsPermitCountFlag_FW", + new RedisFixedWindowRateLimiterOptions + { + PermitLimit = 5, + Window = TimeSpan.FromMinutes(1), + ConnectionMultiplexerFactory = Fixture.ConnectionMultiplexerFactory, + }); + + using var lease = await limiter.AcquireAsync(permitCount: 3); + Assert.True(lease.IsAcquired); + + using var lease2 = await limiter.AcquireAsync(permitCount: 3); + Assert.False(lease2.IsAcquired); + + using var lease3 = await limiter.AcquireAsync(permitCount: 2); + Assert.True(lease3.IsAcquired); + } } } From 04cad092ff09cfbb23b0eb3d7852fddd142edee3 Mon Sep 17 00:00:00 2001 From: Manuel Spezzani Date: Fri, 7 Apr 2023 12:06:31 +0200 Subject: [PATCH 3/4] RedisTokenBucketRateLimiter now supports permitCount parameter --- .../TokenBucket/RedisTokenBucketManager.cs | 8 +++---- .../RedisTokenBucketRateLimiter.cs | 8 +++---- .../UnitTests/TokenBucketUnitTests.cs | 23 +++++++++++++++++++ 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/src/RedisRateLimiting/TokenBucket/RedisTokenBucketManager.cs b/src/RedisRateLimiting/TokenBucket/RedisTokenBucketManager.cs index e99458a..c94b0c7 100644 --- a/src/RedisRateLimiting/TokenBucket/RedisTokenBucketManager.cs +++ b/src/RedisRateLimiting/TokenBucket/RedisTokenBucketManager.cs @@ -56,7 +56,7 @@ public RedisTokenBucketManager( RateLimitTimestampKey = new RedisKey($"rl:{{{partitionKey}}}:ts"); } - internal async Task TryAcquireLeaseAsync() + internal async Task TryAcquireLeaseAsync(int permitCount) { var now = DateTimeOffset.UtcNow; var nowUnixTimeSeconds = now.ToUnixTimeSeconds(); @@ -73,7 +73,7 @@ internal async Task TryAcquireLeaseAsync() tokens_per_period = _options.TokensPerPeriod, token_limit = _options.TokenLimit, replenish_period = _options.ReplenishmentPeriod.TotalSeconds, - permit_count = 1D, + permit_count = permitCount, }); var result = new RedisTokenBucketResponse(); @@ -87,7 +87,7 @@ internal async Task TryAcquireLeaseAsync() return result; } - internal RedisTokenBucketResponse TryAcquireLease() + internal RedisTokenBucketResponse TryAcquireLease(int permitCount) { var now = DateTimeOffset.UtcNow; var nowUnixTimeSeconds = now.ToUnixTimeSeconds(); @@ -104,7 +104,7 @@ internal RedisTokenBucketResponse TryAcquireLease() tokens_per_period = _options.TokensPerPeriod, token_limit = _options.TokenLimit, replenish_period = _options.ReplenishmentPeriod.TotalSeconds, - permit_count = 1D, + permit_count = permitCount, }); var result = new RedisTokenBucketResponse(); diff --git a/src/RedisRateLimiting/TokenBucket/RedisTokenBucketRateLimiter.cs b/src/RedisRateLimiting/TokenBucket/RedisTokenBucketRateLimiter.cs index efc00ef..4dd3c0d 100644 --- a/src/RedisRateLimiting/TokenBucket/RedisTokenBucketRateLimiter.cs +++ b/src/RedisRateLimiting/TokenBucket/RedisTokenBucketRateLimiter.cs @@ -60,7 +60,7 @@ protected override ValueTask AcquireAsyncCore(int permitCount, C throw new ArgumentOutOfRangeException(nameof(permitCount), permitCount, string.Format("{0} permit(s) exceeds the permit limit of {1}.", permitCount, _options.TokenLimit)); } - return AcquireAsyncCoreInternal(); + return AcquireAsyncCoreInternal(permitCount); } protected override RateLimitLease AttemptAcquireCore(int permitCount) @@ -75,7 +75,7 @@ protected override RateLimitLease AttemptAcquireCore(int permitCount) Limit = _options.TokenLimit, }; - var response = _redisManager.TryAcquireLease(); + var response = _redisManager.TryAcquireLease(permitCount); leaseContext.Allowed = response.Allowed; leaseContext.Count = response.Count; @@ -88,14 +88,14 @@ protected override RateLimitLease AttemptAcquireCore(int permitCount) return new TokenBucketLease(isAcquired: false, leaseContext); } - private async ValueTask AcquireAsyncCoreInternal() + private async ValueTask AcquireAsyncCoreInternal(int permitCount) { var leaseContext = new TokenBucketLeaseContext { Limit = _options.TokenLimit, }; - var response = await _redisManager.TryAcquireLeaseAsync(); + var response = await _redisManager.TryAcquireLeaseAsync(permitCount); leaseContext.Allowed = response.Allowed; leaseContext.Count = response.Count; diff --git a/test/RedisRateLimiting.Tests/UnitTests/TokenBucketUnitTests.cs b/test/RedisRateLimiting.Tests/UnitTests/TokenBucketUnitTests.cs index 5140e11..df91a71 100644 --- a/test/RedisRateLimiting.Tests/UnitTests/TokenBucketUnitTests.cs +++ b/test/RedisRateLimiting.Tests/UnitTests/TokenBucketUnitTests.cs @@ -107,5 +107,28 @@ public async Task CanAcquireAsyncResource() using var lease2 = await limiter.AcquireAsync(); Assert.False(lease2.IsAcquired); } + + [Fact] + public async Task SupportsPermitCountFlag() + { + using var limiter = new RedisTokenBucketRateLimiter( + "Test_SupportsPermitCountFlag_TB", + new RedisTokenBucketRateLimiterOptions + { + TokenLimit = 5, + TokensPerPeriod = 5, + ReplenishmentPeriod = TimeSpan.FromMinutes(1), + ConnectionMultiplexerFactory = Fixture.ConnectionMultiplexerFactory, + }); + + using var lease = await limiter.AcquireAsync(4); + Assert.True(lease.IsAcquired); + + using var lease2 = await limiter.AcquireAsync(3); + Assert.False(lease2.IsAcquired); + + using var lease3 = await limiter.AcquireAsync(1); + Assert.True(lease3.IsAcquired); + } } } From 484237eb67ebe055ef6c99cfddd2bc6be1d7422e Mon Sep 17 00:00:00 2001 From: Manuel Spezzani Date: Fri, 7 Apr 2023 16:25:50 +0200 Subject: [PATCH 4/4] RedisConcurrencyRateLimiter now supports permitCount parameter --- .../Concurrency/RedisConcurrencyManager.cs | 66 ++++++++++++++----- .../RedisConcurrencyRateLimiter.cs | 20 +++--- .../UnitTests/ConcurrencyUnitTests.cs | 53 +++++++++++++++ 3 files changed, 114 insertions(+), 25 deletions(-) diff --git a/src/RedisRateLimiting/Concurrency/RedisConcurrencyManager.cs b/src/RedisRateLimiting/Concurrency/RedisConcurrencyManager.cs index 3c4d36e..7e942cc 100644 --- a/src/RedisRateLimiting/Concurrency/RedisConcurrencyManager.cs +++ b/src/RedisRateLimiting/Concurrency/RedisConcurrencyManager.cs @@ -1,5 +1,6 @@ using StackExchange.Redis; using System; +using System.Collections.Generic; using System.Threading.RateLimiting; using System.Threading.Tasks; @@ -18,6 +19,7 @@ internal class RedisConcurrencyManager local queue_limit = tonumber(@queue_limit) local try_enqueue = tonumber(@try_enqueue) local timestamp = tonumber(@current_time) + local requested = tonumber(@permit_count) -- max seconds it takes to complete a request local ttl = 60 @@ -29,10 +31,19 @@ internal class RedisConcurrencyManager end local count = redis.call(""zcard"", @rate_limit_key) - local allowed = count < limit + local allowed = count + requested <= limit local queued = false local queue_count = 0 + local addparams = {} + local remparams = {} + for i=1,requested do + local index = i*2 + addparams[index-1]=timestamp + addparams[index]=@unique_id..':'..tostring(i) + remparams[i]=addparams[index] + end + if allowed then @@ -45,23 +56,23 @@ internal class RedisConcurrencyManager if queue_count == 0 or try_enqueue == 0 then - redis.call(""zadd"", @rate_limit_key, timestamp, @unique_id) + redis.call(""zadd"", @rate_limit_key, unpack(addparams)) if queue_limit > 0 then -- remove from pending queue - redis.call(""zrem"", @queue_key, @unique_id) + redis.call(""zrem"", @queue_key, unpack(remparams)) end else -- queue the current request next in line if we have any requests in the pending queue allowed = false - queued = queue_count + count < limit + queue_limit + queued = queue_count + count + requested <= limit + queue_limit if queued then - redis.call(""zadd"", @queue_key, timestamp, @unique_id) + redis.call(""zadd"", @queue_key, unpack(addparams)) end end @@ -72,11 +83,11 @@ internal class RedisConcurrencyManager then queue_count = redis.call(""zcard"", @queue_key) - queued = queue_count < queue_limit + queued = queue_count + requested <= queue_limit if queued then - redis.call(""zadd"", @queue_key, timestamp, @unique_id) + redis.call(""zadd"", @queue_key, unpack(addparams)) end end @@ -84,11 +95,11 @@ internal class RedisConcurrencyManager if allowed then - redis.call(""hincrby"", @stats_key, 'total_successful', 1) + redis.call(""hincrby"", @stats_key, 'total_successful', requested) else if queued == false and try_enqueue == 1 then - redis.call(""hincrby"", @stats_key, 'total_failed', 1) + redis.call(""hincrby"", @stats_key, 'total_failed', requested) end end @@ -114,7 +125,7 @@ public RedisConcurrencyManager( StatsRateLimitKey = new RedisKey($"rl:{{{partitionKey}}}:stats"); } - internal async Task TryAcquireLeaseAsync(string requestId, bool tryEnqueue = false) + internal async Task TryAcquireLeaseAsync(string requestId, int permitCount, bool tryEnqueue = false) { var nowUnixTimeSeconds = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); @@ -132,6 +143,7 @@ internal async Task TryAcquireLeaseAsync(string reques stats_key = StatsRateLimitKey, current_time = nowUnixTimeSeconds, unique_id = requestId, + permit_count = permitCount }); var result = new RedisConcurrencyResponse(); @@ -147,7 +159,7 @@ internal async Task TryAcquireLeaseAsync(string reques return result; } - internal RedisConcurrencyResponse TryAcquireLease(string requestId, bool tryEnqueue = false) + internal RedisConcurrencyResponse TryAcquireLease(string requestId, int permitCount, bool tryEnqueue = false) { var nowUnixTimeSeconds = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); @@ -165,6 +177,7 @@ internal RedisConcurrencyResponse TryAcquireLease(string requestId, bool tryEnqu stats_key = StatsRateLimitKey, current_time = nowUnixTimeSeconds, unique_id = requestId, + permit_count = permitCount }); var result = new RedisConcurrencyResponse(); @@ -180,22 +193,41 @@ internal RedisConcurrencyResponse TryAcquireLease(string requestId, bool tryEnqu return result; } - internal void ReleaseLease(string requestId) + internal void ReleaseLease(string requestId, int permitCount) { var database = _connectionMultiplexer.GetDatabase(); - database.SortedSetRemove(RateLimitKey, requestId); + + for (var i = 1; i <= permitCount; i++) + { + database.SortedSetRemove(RateLimitKey, $"{requestId}:{i}"); + } } - internal async Task ReleaseLeaseAsync(string requestId) + internal Task ReleaseLeaseAsync(string requestId, int permitCount) { var database = _connectionMultiplexer.GetDatabase(); - await database.SortedSetRemoveAsync(RateLimitKey, requestId); + var tasks = new List(permitCount); + + for (var i = 1; i <= permitCount; i++) + { + tasks.Add(database.SortedSetRemoveAsync(RateLimitKey, $"{requestId}:{i}")); + } + + return Task.WhenAll(tasks); } - internal async Task ReleaseQueueLeaseAsync(string requestId) + internal Task ReleaseQueueLeaseAsync(string requestId, int permitCount) { var database = _connectionMultiplexer.GetDatabase(); - await database.SortedSetRemoveAsync(QueueRateLimitKey, requestId); + + var tasks = new List(permitCount); + + for (var i = 1; i <= permitCount; i++) + { + tasks.Add(database.SortedSetRemoveAsync(QueueRateLimitKey, $"{requestId}:{i}")); + } + + return Task.WhenAll(tasks); } internal RateLimiterStatistics? GetStatistics() diff --git a/src/RedisRateLimiting/Concurrency/RedisConcurrencyRateLimiter.cs b/src/RedisRateLimiting/Concurrency/RedisConcurrencyRateLimiter.cs index 6c7479a..f7c91df 100644 --- a/src/RedisRateLimiting/Concurrency/RedisConcurrencyRateLimiter.cs +++ b/src/RedisRateLimiting/Concurrency/RedisConcurrencyRateLimiter.cs @@ -71,7 +71,7 @@ protected override ValueTask AcquireAsyncCore(int permitCount, C throw new ArgumentOutOfRangeException(nameof(permitCount), permitCount, string.Format("{0} permit(s) exceeds the permit limit of {1}.", permitCount, _options.PermitLimit)); } - return AcquireAsyncCoreInternal(cancellationToken); + return AcquireAsyncCoreInternal(permitCount, cancellationToken); } protected override RateLimitLease AttemptAcquireCore(int permitCount) @@ -85,9 +85,10 @@ protected override RateLimitLease AttemptAcquireCore(int permitCount) { Limit = _options.PermitLimit, RequestId = Guid.NewGuid().ToString(), + PermitCount = permitCount }; - var response = _redisManager.TryAcquireLease(leaseContext.RequestId); + var response = _redisManager.TryAcquireLease(leaseContext.RequestId, permitCount); leaseContext.Count = response.Count; @@ -99,15 +100,16 @@ protected override RateLimitLease AttemptAcquireCore(int permitCount) return new ConcurrencyLease(isAcquired: false, this, leaseContext); } - private async ValueTask AcquireAsyncCoreInternal(CancellationToken cancellationToken) + private async ValueTask AcquireAsyncCoreInternal(int permitCount, CancellationToken cancellationToken) { var leaseContext = new ConcurencyLeaseContext { Limit = _options.PermitLimit, RequestId = Guid.NewGuid().ToString(), + PermitCount = permitCount }; - var response = await _redisManager.TryAcquireLeaseAsync(leaseContext.RequestId, tryEnqueue: true); + var response = await _redisManager.TryAcquireLeaseAsync(leaseContext.RequestId, permitCount, tryEnqueue: true); leaseContext.Count = response.Count; @@ -148,7 +150,7 @@ private void Release(ConcurencyLeaseContext leaseContext) { if (leaseContext.RequestId is null) return; - _redisManager.ReleaseLease(leaseContext.RequestId); + _redisManager.ReleaseLease(leaseContext.RequestId, leaseContext.PermitCount); } private async Task StartDequeueTimerAsync(PeriodicTimer periodicTimer) @@ -170,7 +172,7 @@ private async Task TryDequeueRequestsAsync() try { // The request was canceled while in the pending queue - await _redisManager.ReleaseQueueLeaseAsync(request.LeaseContext!.RequestId!); + await _redisManager.ReleaseQueueLeaseAsync(request.LeaseContext!.RequestId!, request.LeaseContext!.PermitCount); } finally { @@ -182,7 +184,7 @@ private async Task TryDequeueRequestsAsync() continue; } - var response = await _redisManager.TryAcquireLeaseAsync(request.LeaseContext!.RequestId!); + var response = await _redisManager.TryAcquireLeaseAsync(request.LeaseContext!.RequestId!, request.LeaseContext!.PermitCount); request.LeaseContext.Count = response.Count; @@ -195,7 +197,7 @@ private async Task TryDequeueRequestsAsync() if (request.TaskCompletionSource?.TrySetResult(pendingLease) == false) { // The request was canceled after we acquired the lease - await _redisManager.ReleaseLeaseAsync(request.LeaseContext!.RequestId!); + await _redisManager.ReleaseLeaseAsync(request.LeaseContext!.RequestId!, request.LeaseContext!.PermitCount); } } finally @@ -257,6 +259,8 @@ private sealed class ConcurencyLeaseContext public long Count { get; set; } public long Limit { get; set; } + + public int PermitCount { get; set; } } private sealed class ConcurrencyLease : RateLimitLease diff --git a/test/RedisRateLimiting.Tests/UnitTests/ConcurrencyUnitTests.cs b/test/RedisRateLimiting.Tests/UnitTests/ConcurrencyUnitTests.cs index 4026684..78a2c79 100644 --- a/test/RedisRateLimiting.Tests/UnitTests/ConcurrencyUnitTests.cs +++ b/test/RedisRateLimiting.Tests/UnitTests/ConcurrencyUnitTests.cs @@ -120,6 +120,59 @@ public void CanAcquireResource() lease.Dispose(); } + [Fact] + public async Task SupportsPermitCountFlag_NoQueue() + { + using var limiter = new RedisConcurrencyRateLimiter( + "Test_SupportsPermitCountFlag_NoQueue_Concurrency", + new RedisConcurrencyRateLimiterOptions + { + PermitLimit = 2, + QueueLimit = 0, + ConnectionMultiplexerFactory = Fixture.ConnectionMultiplexerFactory, + }); + + var lease1 = await limiter.AcquireAsync(permitCount: 2); + var lease2 = await limiter.AcquireAsync(permitCount: 1); + + Assert.True(lease1.IsAcquired); + Assert.False(lease2.IsAcquired); + + lease1.Dispose(); + lease2.Dispose(); + } + + [Fact] + public async Task SupportsPermitCountFlag_WithQueue() + { + using var limiter = new RedisConcurrencyRateLimiter( + "Test_SupportsPermitCountFlag_WithQueue_Concurrency", + new RedisConcurrencyRateLimiterOptions + { + PermitLimit = 2, + QueueLimit = 2, + ConnectionMultiplexerFactory = Fixture.ConnectionMultiplexerFactory, + }); + + var lease1 = await limiter.AcquireAsync(permitCount: 2); + Assert.True(lease1.IsAcquired); + + var wait2 = limiter.AcquireAsync(permitCount: 2); + var wait3 = limiter.AcquireAsync(permitCount: 1); + + await Task.Delay(1000); + lease1.Dispose(); + + var lease2 = await wait2; + Assert.True(lease2.IsAcquired); + + var lease3 = await wait3; + Assert.False(lease3.IsAcquired); + + lease2.Dispose(); + lease3.Dispose(); + } + [Fact] public async Task CanAcquireResourceAsyncQueuesAndGrabsOldest() {