diff --git a/src/RedisRateLimiting/Concurrency/RedisConcurrencyManager.cs b/src/RedisRateLimiting/Concurrency/RedisConcurrencyManager.cs index d8bb2b1..c619dbb 100644 --- a/src/RedisRateLimiting/Concurrency/RedisConcurrencyManager.cs +++ b/src/RedisRateLimiting/Concurrency/RedisConcurrencyManager.cs @@ -1,5 +1,6 @@ using StackExchange.Redis; using System; +using System.Collections.Generic; using System.Threading.RateLimiting; using System.Threading.Tasks; @@ -18,6 +19,7 @@ internal class RedisConcurrencyManager local queue_limit = tonumber(@queue_limit) local try_enqueue = tonumber(@try_enqueue) local timestamp = tonumber(@current_time) + local requested = tonumber(@permit_count) -- max seconds it takes to complete a request local ttl = 60 @@ -29,10 +31,19 @@ internal class RedisConcurrencyManager end local count = redis.call(""zcard"", @rate_limit_key) - local allowed = count < limit + local allowed = count + requested <= limit local queued = false local queue_count = 0 + local addparams = {} + local remparams = {} + for i=1,requested do + local index = i*2 + addparams[index-1]=timestamp + addparams[index]=@unique_id..':'..tostring(i) + remparams[i]=addparams[index] + end + if allowed then @@ -45,23 +56,23 @@ internal class RedisConcurrencyManager if queue_count == 0 or try_enqueue == 0 then - redis.call(""zadd"", @rate_limit_key, timestamp, @unique_id) + redis.call(""zadd"", @rate_limit_key, unpack(addparams)) if queue_limit > 0 then -- remove from pending queue - redis.call(""zrem"", @queue_key, @unique_id) + redis.call(""zrem"", @queue_key, unpack(remparams)) end else -- queue the current request next in line if we have any requests in the pending queue allowed = false - queued = queue_count + count < limit + queue_limit + queued = queue_count + count + requested <= limit + queue_limit if queued then - redis.call(""zadd"", @queue_key, timestamp, @unique_id) + redis.call(""zadd"", @queue_key, unpack(addparams)) end end @@ -72,11 +83,11 @@ internal class RedisConcurrencyManager then queue_count = redis.call(""zcard"", @queue_key) - queued = queue_count < queue_limit + queued = queue_count + requested <= queue_limit if queued then - redis.call(""zadd"", @queue_key, timestamp, @unique_id) + redis.call(""zadd"", @queue_key, unpack(addparams)) end end @@ -84,11 +95,11 @@ internal class RedisConcurrencyManager if allowed then - redis.call(""hincrby"", @stats_key, 'total_successful', 1) + redis.call(""hincrby"", @stats_key, 'total_successful', requested) else if queued == false and try_enqueue == 1 then - redis.call(""hincrby"", @stats_key, 'total_failed', 1) + redis.call(""hincrby"", @stats_key, 'total_failed', requested) end end @@ -114,7 +125,7 @@ public RedisConcurrencyManager( StatsRateLimitKey = new RedisKey($"rl:{{{partitionKey}}}:stats"); } - internal async Task TryAcquireLeaseAsync(string requestId, bool tryEnqueue = false) + internal async Task TryAcquireLeaseAsync(string requestId, int permitCount, bool tryEnqueue = false) { var nowUnixTimeSeconds = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); @@ -129,6 +140,7 @@ internal async Task TryAcquireLeaseAsync(string reques stats_key = StatsRateLimitKey, permit_limit = (RedisValue)_options.PermitLimit, try_enqueue = (RedisValue)(tryEnqueue ? 1 : 0), + permit_count = (RedisValue)permitCount, queue_limit = (RedisValue)_options.QueueLimit, current_time = (RedisValue)nowUnixTimeSeconds, unique_id = (RedisValue)requestId, @@ -147,7 +159,7 @@ internal async Task TryAcquireLeaseAsync(string reques return result; } - internal RedisConcurrencyResponse TryAcquireLease(string requestId, bool tryEnqueue = false) + internal RedisConcurrencyResponse TryAcquireLease(string requestId, int permitCount, bool tryEnqueue = false) { var nowUnixTimeSeconds = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); @@ -162,6 +174,7 @@ internal RedisConcurrencyResponse TryAcquireLease(string requestId, bool tryEnqu stats_key = StatsRateLimitKey, permit_limit = (RedisValue)_options.PermitLimit, try_enqueue = (RedisValue)(tryEnqueue ? 1 : 0), + permit_count = (RedisValue)permitCount, queue_limit = (RedisValue)_options.QueueLimit, current_time = (RedisValue)nowUnixTimeSeconds, unique_id = (RedisValue)requestId, @@ -180,22 +193,41 @@ internal RedisConcurrencyResponse TryAcquireLease(string requestId, bool tryEnqu return result; } - internal void ReleaseLease(string requestId) + internal void ReleaseLease(string requestId, int permitCount) { var database = _connectionMultiplexer.GetDatabase(); - database.SortedSetRemove(RateLimitKey, requestId); + + for (var i = 1; i <= permitCount; i++) + { + database.SortedSetRemove(RateLimitKey, $"{requestId}:{i}"); + } } - internal async Task ReleaseLeaseAsync(string requestId) + internal Task ReleaseLeaseAsync(string requestId, int permitCount) { var database = _connectionMultiplexer.GetDatabase(); - await database.SortedSetRemoveAsync(RateLimitKey, requestId); + var tasks = new List(permitCount); + + for (var i = 1; i <= permitCount; i++) + { + tasks.Add(database.SortedSetRemoveAsync(RateLimitKey, $"{requestId}:{i}")); + } + + return Task.WhenAll(tasks); } - internal async Task ReleaseQueueLeaseAsync(string requestId) + internal Task ReleaseQueueLeaseAsync(string requestId, int permitCount) { var database = _connectionMultiplexer.GetDatabase(); - await database.SortedSetRemoveAsync(QueueRateLimitKey, requestId); + + var tasks = new List(permitCount); + + for (var i = 1; i <= permitCount; i++) + { + tasks.Add(database.SortedSetRemoveAsync(QueueRateLimitKey, $"{requestId}:{i}")); + } + + return Task.WhenAll(tasks); } internal RateLimiterStatistics? GetStatistics() diff --git a/src/RedisRateLimiting/Concurrency/RedisConcurrencyRateLimiter.cs b/src/RedisRateLimiting/Concurrency/RedisConcurrencyRateLimiter.cs index e2b162c..7e72c3f 100644 --- a/src/RedisRateLimiting/Concurrency/RedisConcurrencyRateLimiter.cs +++ b/src/RedisRateLimiting/Concurrency/RedisConcurrencyRateLimiter.cs @@ -71,7 +71,7 @@ protected override ValueTask AcquireAsyncCore(int permitCount, C throw new ArgumentOutOfRangeException(nameof(permitCount), permitCount, string.Format("{0} permit(s) exceeds the permit limit of {1}.", permitCount, _options.PermitLimit)); } - return AcquireAsyncCoreInternal(cancellationToken); + return AcquireAsyncCoreInternal(permitCount, cancellationToken); } protected override RateLimitLease AttemptAcquireCore(int permitCount) @@ -85,9 +85,10 @@ protected override RateLimitLease AttemptAcquireCore(int permitCount) { Limit = _options.PermitLimit, RequestId = Guid.NewGuid().ToString(), + PermitCount = permitCount }; - var response = _redisManager.TryAcquireLease(leaseContext.RequestId); + var response = _redisManager.TryAcquireLease(leaseContext.RequestId, permitCount); leaseContext.Count = response.Count; @@ -99,15 +100,16 @@ protected override RateLimitLease AttemptAcquireCore(int permitCount) return new ConcurrencyLease(isAcquired: false, this, leaseContext); } - private async ValueTask AcquireAsyncCoreInternal(CancellationToken cancellationToken) + private async ValueTask AcquireAsyncCoreInternal(int permitCount, CancellationToken cancellationToken) { var leaseContext = new ConcurencyLeaseContext { Limit = _options.PermitLimit, RequestId = Guid.NewGuid().ToString(), + PermitCount = permitCount }; - var response = await _redisManager.TryAcquireLeaseAsync(leaseContext.RequestId, tryEnqueue: true); + var response = await _redisManager.TryAcquireLeaseAsync(leaseContext.RequestId, permitCount, tryEnqueue: true); leaseContext.Count = response.Count; @@ -148,7 +150,7 @@ private void Release(ConcurencyLeaseContext leaseContext) { if (leaseContext.RequestId is null) return; - _redisManager.ReleaseLease(leaseContext.RequestId); + _redisManager.ReleaseLease(leaseContext.RequestId, leaseContext.PermitCount); } private async Task StartDequeueTimerAsync(PeriodicTimer periodicTimer) @@ -170,7 +172,7 @@ private async Task TryDequeueRequestsAsync() try { // The request was canceled while in the pending queue - await _redisManager.ReleaseQueueLeaseAsync(request.LeaseContext!.RequestId!); + await _redisManager.ReleaseQueueLeaseAsync(request.LeaseContext!.RequestId!, request.LeaseContext!.PermitCount); } finally { @@ -182,7 +184,7 @@ private async Task TryDequeueRequestsAsync() continue; } - var response = await _redisManager.TryAcquireLeaseAsync(request.LeaseContext!.RequestId!); + var response = await _redisManager.TryAcquireLeaseAsync(request.LeaseContext!.RequestId!, request.LeaseContext!.PermitCount); request.LeaseContext.Count = response.Count; @@ -195,7 +197,7 @@ private async Task TryDequeueRequestsAsync() if (request.TaskCompletionSource?.TrySetResult(pendingLease) == false) { // The request was canceled after we acquired the lease - await _redisManager.ReleaseLeaseAsync(request.LeaseContext!.RequestId!); + await _redisManager.ReleaseLeaseAsync(request.LeaseContext!.RequestId!, request.LeaseContext!.PermitCount); } } finally @@ -256,6 +258,8 @@ private sealed class ConcurencyLeaseContext public long Count { get; set; } public long Limit { get; set; } + + public int PermitCount { get; set; } } private sealed class ConcurrencyLease : RateLimitLease diff --git a/src/RedisRateLimiting/FixedWindow/RedisFixedWindowManager.cs b/src/RedisRateLimiting/FixedWindow/RedisFixedWindowManager.cs index 37de775..6f12ce5 100644 --- a/src/RedisRateLimiting/FixedWindow/RedisFixedWindowManager.cs +++ b/src/RedisRateLimiting/FixedWindow/RedisFixedWindowManager.cs @@ -13,6 +13,9 @@ internal class RedisFixedWindowManager private static readonly LuaScript Script = LuaScript.Prepare( @"local expires_at = tonumber(redis.call(""get"", @expires_at_key)) + local current = tonumber(redis.call(""get"", @rate_limit_key)) + local requested = tonumber(@increment_amount) + local limit = tonumber(@permit_limit) if not expires_at or expires_at < tonumber(@current_time) then -- this is either a brand new window, @@ -27,13 +30,19 @@ internal class RedisFixedWindowManager redis.call(""expireat"", @expires_at_key, @next_expires_at + 1) -- since the database was updated, return the new value expires_at = @next_expires_at + current = 0 end - -- now that the window either already exists or it was freshly initialized, - -- increment the counter(`incrby` returns a number) - local current = redis.call(""incrby"", @rate_limit_key, @increment_amount) + local allowed = current + requested <= limit - return { current, expires_at }"); + if allowed + then + -- now that the window either already exists or it was freshly initialized, + -- increment the counter(`incrby` returns a number) + current = redis.call(""incrby"", @rate_limit_key, @increment_amount) + end + + return { current, expires_at, allowed }"); public RedisFixedWindowManager( string partitionKey, @@ -46,7 +55,7 @@ public RedisFixedWindowManager( RateLimitExpireKey = new RedisKey($"rl:{{{partitionKey}}}:exp"); } - internal async Task TryAcquireLeaseAsync() + internal async Task TryAcquireLeaseAsync(int permitCount) { var now = DateTimeOffset.UtcNow; var nowUnixTimeSeconds = now.ToUnixTimeSeconds(); @@ -59,9 +68,10 @@ internal async Task TryAcquireLeaseAsync() { rate_limit_key = RateLimitKey, expires_at_key = RateLimitExpireKey, + permit_limit = _options.PermitLimit, next_expires_at = (RedisValue)now.Add(_options.Window).ToUnixTimeSeconds(), current_time = (RedisValue)nowUnixTimeSeconds, - increment_amount = (RedisValue)1D, + increment_amount = (RedisValue)permitCount, }); var result = new RedisFixedWindowResponse(); @@ -70,13 +80,14 @@ internal async Task TryAcquireLeaseAsync() { result.Count = (long)response[0]; result.ExpiresAt = (long)response[1]; + result.Allowed = (bool)response[2]; result.RetryAfter = TimeSpan.FromSeconds(result.ExpiresAt - nowUnixTimeSeconds); } return result; } - internal RedisFixedWindowResponse TryAcquireLease() + internal RedisFixedWindowResponse TryAcquireLease(int permitCount) { var now = DateTimeOffset.UtcNow; var nowUnixTimeSeconds = now.ToUnixTimeSeconds(); @@ -89,9 +100,10 @@ internal RedisFixedWindowResponse TryAcquireLease() { rate_limit_key = RateLimitKey, expires_at_key = RateLimitExpireKey, + permit_limit = _options.PermitLimit, next_expires_at = (RedisValue)now.Add(_options.Window).ToUnixTimeSeconds(), current_time = (RedisValue)nowUnixTimeSeconds, - increment_amount = (RedisValue)1D, + increment_amount = (RedisValue)permitCount, }); var result = new RedisFixedWindowResponse(); @@ -100,6 +112,7 @@ internal RedisFixedWindowResponse TryAcquireLease() { result.Count = (long)response[0]; result.ExpiresAt = (long)response[1]; + result.Allowed = (bool)response[2]; result.RetryAfter = TimeSpan.FromSeconds(result.ExpiresAt - nowUnixTimeSeconds); } @@ -112,5 +125,6 @@ internal class RedisFixedWindowResponse internal long ExpiresAt { get; set; } internal TimeSpan RetryAfter { get; set; } internal long Count { get; set; } + internal bool Allowed { get; set; } } } diff --git a/src/RedisRateLimiting/FixedWindow/RedisFixedWindowRateLimiter.cs b/src/RedisRateLimiting/FixedWindow/RedisFixedWindowRateLimiter.cs index 45734b6..0ba4b39 100644 --- a/src/RedisRateLimiting/FixedWindow/RedisFixedWindowRateLimiter.cs +++ b/src/RedisRateLimiting/FixedWindow/RedisFixedWindowRateLimiter.cs @@ -55,7 +55,7 @@ protected override ValueTask AcquireAsyncCore(int permitCount, C throw new ArgumentOutOfRangeException(nameof(permitCount), permitCount, string.Format("{0} permit(s) exceeds the permit limit of {1}.", permitCount, _options.PermitLimit)); } - return AcquireAsyncCoreInternal(); + return AcquireAsyncCoreInternal(permitCount); } protected override RateLimitLease AttemptAcquireCore(int permitCount) @@ -71,21 +71,16 @@ protected override RateLimitLease AttemptAcquireCore(int permitCount) Window = _options.Window, }; - var response = _redisManager.TryAcquireLease(); + var response = _redisManager.TryAcquireLease(permitCount); leaseContext.Count = response.Count; leaseContext.RetryAfter = response.RetryAfter; leaseContext.ExpiresAt = DateTimeOffset.FromUnixTimeSeconds(response.ExpiresAt); - - if (leaseContext.Count > _options.PermitLimit) - { - return new FixedWindowLease(isAcquired: false, leaseContext); - } - - return new FixedWindowLease(isAcquired: true, leaseContext); + + return new FixedWindowLease(isAcquired: response.Allowed, leaseContext); } - private async ValueTask AcquireAsyncCoreInternal() + private async ValueTask AcquireAsyncCoreInternal(int permitCount) { var leaseContext = new FixedWindowLeaseContext { @@ -93,17 +88,12 @@ private async ValueTask AcquireAsyncCoreInternal() Window = _options.Window, }; - var response = await _redisManager.TryAcquireLeaseAsync(); + var response = await _redisManager.TryAcquireLeaseAsync(permitCount); leaseContext.Count = response.Count; leaseContext.RetryAfter = response.RetryAfter; - if (leaseContext.Count > _options.PermitLimit) - { - return new FixedWindowLease(isAcquired: false, leaseContext); - } - - return new FixedWindowLease(isAcquired: true, leaseContext); + return new FixedWindowLease(isAcquired: response.Allowed, leaseContext); } private sealed class FixedWindowLeaseContext diff --git a/src/RedisRateLimiting/SlidingWindow/RedisSlidingWindowManager.cs b/src/RedisRateLimiting/SlidingWindow/RedisSlidingWindowManager.cs index 3708c43..00ee986 100644 --- a/src/RedisRateLimiting/SlidingWindow/RedisSlidingWindowManager.cs +++ b/src/RedisRateLimiting/SlidingWindow/RedisSlidingWindowManager.cs @@ -16,16 +16,24 @@ internal class RedisSlidingWindowManager @"local limit = tonumber(@permit_limit) local timestamp = tonumber(@current_time) local window = tonumber(@window) + local requested = tonumber(@permit_count) + + local zaddparams = {} + for i=1,requested do + local index = i*2 + zaddparams[index-1]=timestamp + zaddparams[index]=@unique_id..':'..tostring(i) + end -- remove all requests outside current window redis.call(""zremrangebyscore"", @rate_limit_key, '-inf', timestamp - window) local count = redis.call(""zcard"", @rate_limit_key) - local allowed = count < limit + local allowed = count + requested <= limit if allowed then - redis.call(""zadd"", @rate_limit_key, timestamp, @unique_id) + redis.call(""zadd"", @rate_limit_key, unpack(zaddparams)) end local expireAtMilliseconds = math.floor((timestamp + window) * 1000 + 1); @@ -58,7 +66,7 @@ public RedisSlidingWindowManager( StatsRateLimitKey = new RedisKey($"rl:{{{partitionKey}}}:stats"); } - internal async Task TryAcquireLeaseAsync(string requestId) + internal async Task TryAcquireLeaseAsync(string requestId, int permitCount) { var now = DateTimeOffset.UtcNow; double nowUnixTimeSeconds = now.ToUnixTimeMilliseconds() / 1000.0; @@ -73,6 +81,7 @@ internal async Task TryAcquireLeaseAsync(string requ stats_key = StatsRateLimitKey, permit_limit = (RedisValue)_options.PermitLimit, window = (RedisValue)_options.Window.TotalSeconds, + permit_count = (RedisValue)permitCount, current_time = (RedisValue)nowUnixTimeSeconds, unique_id = (RedisValue)requestId, }); @@ -88,7 +97,7 @@ internal async Task TryAcquireLeaseAsync(string requ return result; } - internal RedisSlidingWindowResponse TryAcquireLease(string requestId) + internal RedisSlidingWindowResponse TryAcquireLease(string requestId, int permitCount) { var now = DateTimeOffset.UtcNow; double nowUnixTimeSeconds = now.ToUnixTimeMilliseconds() / 1000.0; @@ -103,6 +112,7 @@ internal RedisSlidingWindowResponse TryAcquireLease(string requestId) stats_key = StatsRateLimitKey, permit_limit = (RedisValue)_options.PermitLimit, window = (RedisValue)_options.Window.TotalSeconds, + permit_count = (RedisValue)permitCount, current_time = (RedisValue)nowUnixTimeSeconds, unique_id = (RedisValue)requestId, }); diff --git a/src/RedisRateLimiting/SlidingWindow/RedisSlidingWindowRateLimiter.cs b/src/RedisRateLimiting/SlidingWindow/RedisSlidingWindowRateLimiter.cs index 5d3eae3..df1dc18 100644 --- a/src/RedisRateLimiting/SlidingWindow/RedisSlidingWindowRateLimiter.cs +++ b/src/RedisRateLimiting/SlidingWindow/RedisSlidingWindowRateLimiter.cs @@ -55,7 +55,7 @@ protected override ValueTask AcquireAsyncCore(int permitCount, C throw new ArgumentOutOfRangeException(nameof(permitCount), permitCount, string.Format("{0} permit(s) exceeds the permit limit of {1}.", permitCount, _options.PermitLimit)); } - return AcquireAsyncCoreInternal(); + return AcquireAsyncCoreInternal(permitCount); } protected override RateLimitLease AttemptAcquireCore(int permitCount) @@ -72,7 +72,7 @@ protected override RateLimitLease AttemptAcquireCore(int permitCount) RequestId = Guid.NewGuid().ToString(), }; - var response = _redisManager.TryAcquireLease(leaseContext.RequestId); + var response = _redisManager.TryAcquireLease(leaseContext.RequestId, permitCount); leaseContext.Count = response.Count; leaseContext.Allowed = response.Allowed; @@ -85,7 +85,7 @@ protected override RateLimitLease AttemptAcquireCore(int permitCount) return new SlidingWindowLease(isAcquired: false, leaseContext); } - private async ValueTask AcquireAsyncCoreInternal() + private async ValueTask AcquireAsyncCoreInternal(int permitCount) { var leaseContext = new SlidingWindowLeaseContext { @@ -94,7 +94,7 @@ private async ValueTask AcquireAsyncCoreInternal() RequestId = Guid.NewGuid().ToString(), }; - var response = await _redisManager.TryAcquireLeaseAsync(leaseContext.RequestId); + var response = await _redisManager.TryAcquireLeaseAsync(leaseContext.RequestId, permitCount); leaseContext.Count = response.Count; leaseContext.Allowed = response.Allowed; diff --git a/src/RedisRateLimiting/TokenBucket/RedisTokenBucketManager.cs b/src/RedisRateLimiting/TokenBucket/RedisTokenBucketManager.cs index b0bbc0a..861debf 100644 --- a/src/RedisRateLimiting/TokenBucket/RedisTokenBucketManager.cs +++ b/src/RedisRateLimiting/TokenBucket/RedisTokenBucketManager.cs @@ -83,7 +83,7 @@ public RedisTokenBucketManager( RateLimitTimestampKey = new RedisKey($"rl:{{{partitionKey}}}:ts"); } - internal async Task TryAcquireLeaseAsync() + internal async Task TryAcquireLeaseAsync(int permitCount) { var database = _connectionMultiplexer.GetDatabase(); @@ -96,7 +96,7 @@ internal async Task TryAcquireLeaseAsync() tokens_per_period = (RedisValue)_options.TokensPerPeriod, token_limit = (RedisValue)_options.TokenLimit, replenish_period = (RedisValue)_options.ReplenishmentPeriod.TotalMilliseconds, - permit_count = (RedisValue)1D, + permit_count = (RedisValue)permitCount, }); var result = new RedisTokenBucketResponse(); @@ -111,7 +111,7 @@ internal async Task TryAcquireLeaseAsync() return result; } - internal RedisTokenBucketResponse TryAcquireLease() + internal RedisTokenBucketResponse TryAcquireLease(int permitCount) { var database = _connectionMultiplexer.GetDatabase(); @@ -124,7 +124,7 @@ internal RedisTokenBucketResponse TryAcquireLease() tokens_per_period = (RedisValue)_options.TokensPerPeriod, token_limit = (RedisValue)_options.TokenLimit, replenish_period = (RedisValue)_options.ReplenishmentPeriod.TotalMilliseconds, - permit_count = (RedisValue)1D, + permit_count = (RedisValue)permitCount, }); var result = new RedisTokenBucketResponse(); diff --git a/src/RedisRateLimiting/TokenBucket/RedisTokenBucketRateLimiter.cs b/src/RedisRateLimiting/TokenBucket/RedisTokenBucketRateLimiter.cs index d3341f1..8e50a69 100644 --- a/src/RedisRateLimiting/TokenBucket/RedisTokenBucketRateLimiter.cs +++ b/src/RedisRateLimiting/TokenBucket/RedisTokenBucketRateLimiter.cs @@ -60,7 +60,7 @@ protected override ValueTask AcquireAsyncCore(int permitCount, C throw new ArgumentOutOfRangeException(nameof(permitCount), permitCount, string.Format("{0} permit(s) exceeds the permit limit of {1}.", permitCount, _options.TokenLimit)); } - return AcquireAsyncCoreInternal(); + return AcquireAsyncCoreInternal(permitCount); } protected override RateLimitLease AttemptAcquireCore(int permitCount) @@ -75,7 +75,7 @@ protected override RateLimitLease AttemptAcquireCore(int permitCount) Limit = _options.TokenLimit, }; - var response = _redisManager.TryAcquireLease(); + var response = _redisManager.TryAcquireLease(permitCount); leaseContext.Allowed = response.Allowed; leaseContext.Count = response.Count; @@ -89,14 +89,14 @@ protected override RateLimitLease AttemptAcquireCore(int permitCount) return new TokenBucketLease(isAcquired: false, leaseContext); } - private async ValueTask AcquireAsyncCoreInternal() + private async ValueTask AcquireAsyncCoreInternal(int permitCount) { var leaseContext = new TokenBucketLeaseContext { Limit = _options.TokenLimit, }; - var response = await _redisManager.TryAcquireLeaseAsync(); + var response = await _redisManager.TryAcquireLeaseAsync(permitCount); leaseContext.Allowed = response.Allowed; leaseContext.Count = response.Count; diff --git a/test/RedisRateLimiting.Tests/UnitTests/ConcurrencyUnitTests.cs b/test/RedisRateLimiting.Tests/UnitTests/ConcurrencyUnitTests.cs index 4026684..78a2c79 100644 --- a/test/RedisRateLimiting.Tests/UnitTests/ConcurrencyUnitTests.cs +++ b/test/RedisRateLimiting.Tests/UnitTests/ConcurrencyUnitTests.cs @@ -120,6 +120,59 @@ public void CanAcquireResource() lease.Dispose(); } + [Fact] + public async Task SupportsPermitCountFlag_NoQueue() + { + using var limiter = new RedisConcurrencyRateLimiter( + "Test_SupportsPermitCountFlag_NoQueue_Concurrency", + new RedisConcurrencyRateLimiterOptions + { + PermitLimit = 2, + QueueLimit = 0, + ConnectionMultiplexerFactory = Fixture.ConnectionMultiplexerFactory, + }); + + var lease1 = await limiter.AcquireAsync(permitCount: 2); + var lease2 = await limiter.AcquireAsync(permitCount: 1); + + Assert.True(lease1.IsAcquired); + Assert.False(lease2.IsAcquired); + + lease1.Dispose(); + lease2.Dispose(); + } + + [Fact] + public async Task SupportsPermitCountFlag_WithQueue() + { + using var limiter = new RedisConcurrencyRateLimiter( + "Test_SupportsPermitCountFlag_WithQueue_Concurrency", + new RedisConcurrencyRateLimiterOptions + { + PermitLimit = 2, + QueueLimit = 2, + ConnectionMultiplexerFactory = Fixture.ConnectionMultiplexerFactory, + }); + + var lease1 = await limiter.AcquireAsync(permitCount: 2); + Assert.True(lease1.IsAcquired); + + var wait2 = limiter.AcquireAsync(permitCount: 2); + var wait3 = limiter.AcquireAsync(permitCount: 1); + + await Task.Delay(1000); + lease1.Dispose(); + + var lease2 = await wait2; + Assert.True(lease2.IsAcquired); + + var lease3 = await wait3; + Assert.False(lease3.IsAcquired); + + lease2.Dispose(); + lease3.Dispose(); + } + [Fact] public async Task CanAcquireResourceAsyncQueuesAndGrabsOldest() { diff --git a/test/RedisRateLimiting.Tests/UnitTests/FixedWindowUnitTests.cs b/test/RedisRateLimiting.Tests/UnitTests/FixedWindowUnitTests.cs index 2c293ff..8c243f7 100644 --- a/test/RedisRateLimiting.Tests/UnitTests/FixedWindowUnitTests.cs +++ b/test/RedisRateLimiting.Tests/UnitTests/FixedWindowUnitTests.cs @@ -86,5 +86,27 @@ public async Task CanAcquireAsyncResource() using var lease2 = await limiter.AcquireAsync(); Assert.False(lease2.IsAcquired); } + + [Fact] + public async Task SupportsPermitCountFlag() + { + using var limiter = new RedisFixedWindowRateLimiter( + "Test_SupportsPermitCountFlag_FW", + new RedisFixedWindowRateLimiterOptions + { + PermitLimit = 5, + Window = TimeSpan.FromMinutes(1), + ConnectionMultiplexerFactory = Fixture.ConnectionMultiplexerFactory, + }); + + using var lease = await limiter.AcquireAsync(permitCount: 3); + Assert.True(lease.IsAcquired); + + using var lease2 = await limiter.AcquireAsync(permitCount: 3); + Assert.False(lease2.IsAcquired); + + using var lease3 = await limiter.AcquireAsync(permitCount: 2); + Assert.True(lease3.IsAcquired); + } } } diff --git a/test/RedisRateLimiting.Tests/UnitTests/SlidingWindowUnitTests.cs b/test/RedisRateLimiting.Tests/UnitTests/SlidingWindowUnitTests.cs index f51db94..4211bb1 100644 --- a/test/RedisRateLimiting.Tests/UnitTests/SlidingWindowUnitTests.cs +++ b/test/RedisRateLimiting.Tests/UnitTests/SlidingWindowUnitTests.cs @@ -133,5 +133,24 @@ public async Task CanAcquireAsyncResourceWithSmallWindow() using var lease4 = await limiter.AcquireAsync(); Assert.False(lease4.IsAcquired); } + + [Fact] + public async Task SupportsPermitCountFlag() + { + using var limiter = new RedisSlidingWindowRateLimiter( + "Test_SupportsPermitCountFlag_SW", + new RedisSlidingWindowRateLimiterOptions + { + PermitLimit = 3, + Window = TimeSpan.FromMinutes(1), + ConnectionMultiplexerFactory = Fixture.ConnectionMultiplexerFactory, + }); + + using var lease = await limiter.AcquireAsync(permitCount: 3); + Assert.True(lease.IsAcquired); + + using var lease2 = await limiter.AcquireAsync(permitCount: 1); + Assert.False(lease2.IsAcquired); + } } } diff --git a/test/RedisRateLimiting.Tests/UnitTests/TokenBucketUnitTests.cs b/test/RedisRateLimiting.Tests/UnitTests/TokenBucketUnitTests.cs index 5140e11..df91a71 100644 --- a/test/RedisRateLimiting.Tests/UnitTests/TokenBucketUnitTests.cs +++ b/test/RedisRateLimiting.Tests/UnitTests/TokenBucketUnitTests.cs @@ -107,5 +107,28 @@ public async Task CanAcquireAsyncResource() using var lease2 = await limiter.AcquireAsync(); Assert.False(lease2.IsAcquired); } + + [Fact] + public async Task SupportsPermitCountFlag() + { + using var limiter = new RedisTokenBucketRateLimiter( + "Test_SupportsPermitCountFlag_TB", + new RedisTokenBucketRateLimiterOptions + { + TokenLimit = 5, + TokensPerPeriod = 5, + ReplenishmentPeriod = TimeSpan.FromMinutes(1), + ConnectionMultiplexerFactory = Fixture.ConnectionMultiplexerFactory, + }); + + using var lease = await limiter.AcquireAsync(4); + Assert.True(lease.IsAcquired); + + using var lease2 = await limiter.AcquireAsync(3); + Assert.False(lease2.IsAcquired); + + using var lease3 = await limiter.AcquireAsync(1); + Assert.True(lease3.IsAcquired); + } } }