1
1
using StackExchange . Redis ;
2
2
using System ;
3
+ using System . Collections . Generic ;
3
4
using System . Threading . RateLimiting ;
4
5
using System . Threading . Tasks ;
5
6
@@ -18,6 +19,7 @@ internal class RedisConcurrencyManager
18
19
local queue_limit = tonumber(@queue_limit)
19
20
local try_enqueue = tonumber(@try_enqueue)
20
21
local timestamp = tonumber(@current_time)
22
+ local requested = tonumber(@permit_count)
21
23
-- max seconds it takes to complete a request
22
24
local ttl = 60
23
25
@@ -29,10 +31,19 @@ internal class RedisConcurrencyManager
29
31
end
30
32
31
33
local count = redis.call(""zcard"", @rate_limit_key)
32
- local allowed = count < limit
34
+ local allowed = count + requested <= limit
33
35
local queued = false
34
36
local queue_count = 0
35
37
38
+ local addparams = {}
39
+ local remparams = {}
40
+ for i=1,requested do
41
+ local index = i*2
42
+ addparams[index-1]=timestamp
43
+ addparams[index]=@unique_id..':'..tostring(i)
44
+ remparams[i]=addparams[index]
45
+ end
46
+
36
47
if allowed
37
48
then
38
49
@@ -45,23 +56,23 @@ internal class RedisConcurrencyManager
45
56
if queue_count == 0 or try_enqueue == 0
46
57
then
47
58
48
- redis.call(""zadd"", @rate_limit_key, timestamp, @unique_id )
59
+ redis.call(""zadd"", @rate_limit_key, unpack(addparams) )
49
60
50
61
if queue_limit > 0
51
62
then
52
63
-- remove from pending queue
53
- redis.call(""zrem"", @queue_key, @unique_id )
64
+ redis.call(""zrem"", @queue_key, unpack(remparams) )
54
65
end
55
66
56
67
else
57
68
-- queue the current request next in line if we have any requests in the pending queue
58
69
allowed = false
59
70
60
- queued = queue_count + count < limit + queue_limit
71
+ queued = queue_count + count + requested <= limit + queue_limit
61
72
62
73
if queued
63
74
then
64
- redis.call(""zadd"", @queue_key, timestamp, @unique_id )
75
+ redis.call(""zadd"", @queue_key, unpack(addparams) )
65
76
end
66
77
67
78
end
@@ -72,23 +83,23 @@ internal class RedisConcurrencyManager
72
83
then
73
84
74
85
queue_count = redis.call(""zcard"", @queue_key)
75
- queued = queue_count < queue_limit
86
+ queued = queue_count + requested <= queue_limit
76
87
77
88
if queued
78
89
then
79
- redis.call(""zadd"", @queue_key, timestamp, @unique_id )
90
+ redis.call(""zadd"", @queue_key, unpack(addparams) )
80
91
end
81
92
82
93
end
83
94
end
84
95
85
96
if allowed
86
97
then
87
- redis.call(""hincrby"", @stats_key, 'total_successful', 1 )
98
+ redis.call(""hincrby"", @stats_key, 'total_successful', requested )
88
99
else
89
100
if queued == false and try_enqueue == 1
90
101
then
91
- redis.call(""hincrby"", @stats_key, 'total_failed', 1 )
102
+ redis.call(""hincrby"", @stats_key, 'total_failed', requested )
92
103
end
93
104
end
94
105
@@ -114,7 +125,7 @@ public RedisConcurrencyManager(
114
125
StatsRateLimitKey = new RedisKey ( $ "rl:{{{partitionKey}}}:stats") ;
115
126
}
116
127
117
- internal async Task < RedisConcurrencyResponse > TryAcquireLeaseAsync ( string requestId , bool tryEnqueue = false )
128
+ internal async Task < RedisConcurrencyResponse > TryAcquireLeaseAsync ( string requestId , int permitCount , bool tryEnqueue = false )
118
129
{
119
130
var nowUnixTimeSeconds = DateTimeOffset . UtcNow . ToUnixTimeSeconds ( ) ;
120
131
@@ -132,6 +143,7 @@ internal async Task<RedisConcurrencyResponse> TryAcquireLeaseAsync(string reques
132
143
stats_key = StatsRateLimitKey ,
133
144
current_time = nowUnixTimeSeconds ,
134
145
unique_id = requestId ,
146
+ permit_count = permitCount
135
147
} ) ;
136
148
137
149
var result = new RedisConcurrencyResponse ( ) ;
@@ -147,7 +159,7 @@ internal async Task<RedisConcurrencyResponse> TryAcquireLeaseAsync(string reques
147
159
return result ;
148
160
}
149
161
150
- internal RedisConcurrencyResponse TryAcquireLease ( string requestId , bool tryEnqueue = false )
162
+ internal RedisConcurrencyResponse TryAcquireLease ( string requestId , int permitCount , bool tryEnqueue = false )
151
163
{
152
164
var nowUnixTimeSeconds = DateTimeOffset . UtcNow . ToUnixTimeSeconds ( ) ;
153
165
@@ -165,6 +177,7 @@ internal RedisConcurrencyResponse TryAcquireLease(string requestId, bool tryEnqu
165
177
stats_key = StatsRateLimitKey ,
166
178
current_time = nowUnixTimeSeconds ,
167
179
unique_id = requestId ,
180
+ permit_count = permitCount
168
181
} ) ;
169
182
170
183
var result = new RedisConcurrencyResponse ( ) ;
@@ -180,22 +193,41 @@ internal RedisConcurrencyResponse TryAcquireLease(string requestId, bool tryEnqu
180
193
return result ;
181
194
}
182
195
183
- internal void ReleaseLease ( string requestId )
196
+ internal void ReleaseLease ( string requestId , int permitCount )
184
197
{
185
198
var database = _connectionMultiplexer . GetDatabase ( ) ;
186
- database . SortedSetRemove ( RateLimitKey , requestId ) ;
199
+
200
+ for ( var i = 1 ; i <= permitCount ; i ++ )
201
+ {
202
+ database . SortedSetRemove ( RateLimitKey , $ "{ requestId } :{ i } ") ;
203
+ }
187
204
}
188
205
189
- internal async Task ReleaseLeaseAsync ( string requestId )
206
+ internal Task ReleaseLeaseAsync ( string requestId , int permitCount )
190
207
{
191
208
var database = _connectionMultiplexer . GetDatabase ( ) ;
192
- await database . SortedSetRemoveAsync ( RateLimitKey , requestId ) ;
209
+ var tasks = new List < Task > ( permitCount ) ;
210
+
211
+ for ( var i = 1 ; i <= permitCount ; i ++ )
212
+ {
213
+ tasks . Add ( database . SortedSetRemoveAsync ( RateLimitKey , $ "{ requestId } :{ i } ") ) ;
214
+ }
215
+
216
+ return Task . WhenAll ( tasks ) ;
193
217
}
194
218
195
- internal async Task ReleaseQueueLeaseAsync ( string requestId )
219
+ internal Task ReleaseQueueLeaseAsync ( string requestId , int permitCount )
196
220
{
197
221
var database = _connectionMultiplexer . GetDatabase ( ) ;
198
- await database . SortedSetRemoveAsync ( QueueRateLimitKey , requestId ) ;
222
+
223
+ var tasks = new List < Task > ( permitCount ) ;
224
+
225
+ for ( var i = 1 ; i <= permitCount ; i ++ )
226
+ {
227
+ tasks . Add ( database . SortedSetRemoveAsync ( QueueRateLimitKey , $ "{ requestId } :{ i } ") ) ;
228
+ }
229
+
230
+ return Task . WhenAll ( tasks ) ;
199
231
}
200
232
201
233
internal RateLimiterStatistics ? GetStatistics ( )
0 commit comments