@@ -6,12 +6,17 @@ use std::time::Duration;
6
6
7
7
use anyhow:: { anyhow, Error , Result } ;
8
8
use async_trait:: async_trait;
9
- use http:: { uri:: Scheme , HeaderMap , HeaderValue , Uri } ;
10
- use hyper:: client:: connect:: Connect ;
11
- use hyper:: { client:: HttpConnector , service:: Service , Client } ;
9
+ use http:: { header, uri:: Scheme , HeaderMap , HeaderName , Uri } ;
10
+ use http_body_util:: { BodyExt , Full } ;
11
+ use hyper:: body:: Bytes ;
12
+ use hyper:: rt:: ReadBufCursor ;
13
+ use hyper_util:: client:: legacy:: connect:: { Connect , HttpConnector } ;
14
+ use hyper_util:: client:: legacy:: Client ;
15
+ use hyper_util:: rt:: TokioExecutor ;
12
16
use pin_project:: pin_project;
13
- use tokio :: io :: { AsyncRead , AsyncWrite , ReadBuf } ;
17
+ use smol_str :: { SmolStr , ToSmolStr } ;
14
18
use tokio:: net:: TcpStream ;
19
+ use tower_service:: Service ;
15
20
use tracing:: { debug, trace, warn} ;
16
21
17
22
use reactor:: gcore:: fastedge:: http:: Headers ;
@@ -20,9 +25,7 @@ use reactor::gcore::fastedge::{
20
25
http_client:: Host ,
21
26
} ;
22
27
23
- const HOST_HEADER_NAME : & str = "host" ;
24
-
25
- type HeaderList = Vec < ( String , String ) > ;
28
+ type HeaderNameList = Vec < SmolStr > ;
26
29
27
30
#[ derive( Clone , Copy , Debug , PartialEq , Eq ) ]
28
31
pub enum BackendStrategy {
@@ -47,17 +50,17 @@ pub struct FastEdgeConnector {
47
50
48
51
#[ derive( Clone , Debug ) ]
49
52
pub struct Backend < C > {
50
- client : Client < C > ,
53
+ client : Client < C , Full < Bytes > > ,
51
54
uri : Uri ,
52
- propagate_headers : HeaderList ,
53
- propagate_header_names : Vec < String > ,
55
+ propagate_headers : HeaderMap ,
56
+ propagate_header_names : HeaderNameList ,
54
57
max_sub_requests : usize ,
55
58
strategy : BackendStrategy ,
56
59
}
57
60
58
61
pub struct Builder {
59
62
uri : Uri ,
60
- propagate_header_names : Vec < String > ,
63
+ propagate_header_names : HeaderNameList ,
61
64
max_sub_requests : usize ,
62
65
strategy : BackendStrategy ,
63
66
}
@@ -67,7 +70,7 @@ impl Builder {
67
70
self . uri = uri;
68
71
self
69
72
}
70
- pub fn propagate_headers_names ( & mut self , propagate : Vec < String > ) -> & mut Self {
73
+ pub fn propagate_headers_names ( & mut self , propagate : HeaderNameList ) -> & mut Self {
71
74
self . propagate_header_names = propagate;
72
75
self
73
76
}
@@ -80,15 +83,15 @@ impl Builder {
80
83
where
81
84
C : Connect + Clone ,
82
85
{
83
- let client = hyper :: Client :: builder ( )
86
+ let client = hyper_util :: client :: legacy :: Client :: builder ( TokioExecutor :: new ( ) )
84
87
. set_host ( false )
85
88
. pool_idle_timeout ( Duration :: from_secs ( 30 ) )
86
89
. build ( connector) ;
87
90
88
91
Backend {
89
92
client,
90
93
uri : self . uri . to_owned ( ) ,
91
- propagate_headers : vec ! [ ] ,
94
+ propagate_headers : HeaderMap :: new ( ) ,
92
95
propagate_header_names : self . propagate_header_names . to_owned ( ) ,
93
96
max_sub_requests : self . max_sub_requests ,
94
97
strategy : self . strategy ,
@@ -106,43 +109,57 @@ impl<C> Backend<C> {
106
109
}
107
110
}
108
111
109
- pub fn uri ( & self ) -> & Uri {
110
- & self . uri
112
+ pub fn uri ( & self ) -> Uri {
113
+ self . uri . to_owned ( )
114
+ }
115
+
116
+ pub fn propagate_header_names ( & self ) -> Vec < SmolStr > {
117
+ self . propagate_header_names . to_owned ( )
111
118
}
112
119
113
120
/// Propagate filtered headers from original requests
114
- pub fn propagate_headers ( & mut self , headers : & HeaderMap < HeaderValue > ) -> Result < ( ) > {
121
+ pub fn propagate_headers ( & mut self , headers : HeaderMap ) -> Result < ( ) > {
115
122
self . propagate_headers . clear ( ) ;
116
123
117
124
if self . strategy == BackendStrategy :: FastEdge {
118
125
let server_name = headers
119
- . get ( "Server_Name " )
126
+ . get ( "server_name " )
120
127
. and_then ( |v| v. to_str ( ) . ok ( ) )
121
128
. ok_or ( anyhow ! ( "header Server_name is missing" ) ) ?;
122
- self . propagate_headers
123
- . push ( ( "Host" . to_string ( ) , be_base_domain ( server_name) ) ) ;
129
+ self . propagate_headers . insert (
130
+ HeaderName :: from_static ( "host" ) ,
131
+ be_base_domain ( server_name) . parse ( ) ?,
132
+ ) ;
124
133
}
125
-
126
- for header_name in self . propagate_header_names . iter ( ) {
127
- if let Some ( value) = headers. get ( header_name) . and_then ( |v| v. to_str ( ) . ok ( ) ) {
128
- trace ! ( "add original request header: {}={}" , header_name, value) ;
129
- self . propagate_headers
130
- . push ( ( header_name. to_string ( ) , value. to_string ( ) ) ) ;
134
+ let headers = headers. into_iter ( ) . filter ( |( k, _) | {
135
+ if let Some ( name) = k {
136
+ self . propagate_header_names . contains ( & name. to_smolstr ( ) )
137
+ } else {
138
+ false
131
139
}
132
- }
140
+ } ) ;
141
+ self . propagate_headers . extend ( headers) ;
142
+
133
143
Ok ( ( ) )
134
144
}
135
145
136
- fn make_request ( & self , req : Request ) -> Result < http:: Request < hyper:: Body > > {
146
+ fn propagate_headers_vec ( & self ) -> Vec < ( String , String ) > {
147
+ self . propagate_headers
148
+ . iter ( )
149
+ . filter_map ( |( k, v) | v. to_str ( ) . ok ( ) . map ( |v| ( k. to_string ( ) , v. to_string ( ) ) ) )
150
+ . collect :: < Vec < ( String , String ) > > ( )
151
+ }
152
+
153
+ fn make_request ( & self , req : Request ) -> Result < http:: Request < Full < Bytes > > > {
137
154
trace ! ( "strategy: {:?}" , self . strategy) ;
138
155
let builder = match self . strategy {
139
156
BackendStrategy :: Direct => {
140
157
let mut headers = req. headers . into_iter ( ) . collect :: < Vec < ( String , String ) > > ( ) ;
141
- headers. extend ( self . propagate_headers . clone ( ) ) ;
158
+ headers. extend ( self . propagate_headers_vec ( ) ) ;
142
159
// CLI has to set Host header from URL, if it is not set already by the request
143
160
if !headers
144
161
. iter ( )
145
- . any ( |( k, _) | k. eq_ignore_ascii_case ( HOST_HEADER_NAME ) )
162
+ . any ( |( k, _) | k. eq_ignore_ascii_case ( header :: HOST . as_str ( ) ) )
146
163
{
147
164
if let Ok ( uri) = req. uri . parse :: < Uri > ( ) {
148
165
if let Some ( host) = uri. authority ( ) . map ( |a| {
@@ -152,7 +169,7 @@ impl<C> Backend<C> {
152
169
a. host ( ) . to_string ( )
153
170
}
154
171
} ) {
155
- headers. push ( ( HOST_HEADER_NAME . to_string ( ) , host) )
172
+ headers. push ( ( header :: HOST . as_str ( ) . to_string ( ) , host) )
156
173
}
157
174
}
158
175
}
@@ -216,8 +233,12 @@ impl<C> Backend<C> {
216
233
} )
217
234
. collect :: < Vec < ( String , String ) > > ( ) ;
218
235
219
- headers. extend ( backend_headers ( & original_url, original_host) ) ;
220
- headers. extend ( self . propagate_headers . clone ( ) ) ;
236
+ headers. push ( ( "fastedge-hostname" . to_string ( ) , original_host) ) ;
237
+ headers. push ( (
238
+ "fastedge-scheme" . to_string ( ) ,
239
+ original_url. scheme_str ( ) . unwrap_or ( "http" ) . to_string ( ) ,
240
+ ) ) ;
241
+ headers. extend ( self . propagate_headers_vec ( ) ) ;
221
242
222
243
let host = canonical_host_name ( & headers, & original_url) ?;
223
244
let url = canonical_url ( & original_url, & host, self . uri . path ( ) ) ?;
@@ -240,7 +261,7 @@ impl<C> Backend<C> {
240
261
} ;
241
262
debug ! ( "request builder: {:?}" , builder) ;
242
263
let body = req. body . unwrap_or_default ( ) ;
243
- builder. body ( hyper :: Body :: from ( body) ) . map_err ( Error :: msg )
264
+ Ok ( builder. body ( Full :: new ( Bytes :: from ( body) ) ) ? )
244
265
}
245
266
}
246
267
@@ -280,7 +301,7 @@ where
280
301
None
281
302
} ;
282
303
283
- let body_bytes = hyper :: body:: to_bytes ( body ) . await ?;
304
+ let body_bytes = body. collect ( ) . await ?. to_bytes ( ) ;
284
305
let body = Some ( body_bytes. to_vec ( ) ) ;
285
306
286
307
trace ! ( ?status, ?headers, len = body_bytes. len( ) , "reply" ) ;
@@ -350,16 +371,6 @@ fn canonical_url(original_url: &Uri, canonical_host: &str, backend_path: &str) -
350
371
. map_err ( Error :: msg)
351
372
}
352
373
353
- fn backend_headers ( original_url : & Uri , original_host : String ) -> HeaderList {
354
- vec ! [ ( "Fastedge-Hostname" . to_string( ) , original_host) , (
355
- "Fastedge-Scheme" . to_string( ) ,
356
- original_url
357
- . scheme_str( )
358
- . unwrap_or( "http" )
359
- . to_string( ) ,
360
- ) ]
361
- }
362
-
363
374
impl FastEdgeConnector {
364
375
pub fn new ( backend : Uri ) -> Self {
365
376
let mut inner = HttpConnector :: new ( ) ;
@@ -385,54 +396,63 @@ impl Service<Uri> for FastEdgeConnector {
385
396
Box :: pin ( async move {
386
397
let conn = connect_fut
387
398
. await
388
- . map ( |inner| Connection { inner } )
399
+ . map ( |inner| Connection {
400
+ inner : inner. into_inner ( ) ,
401
+ } )
389
402
. map_err ( Box :: new) ?;
390
403
Ok ( conn)
391
404
} )
392
405
}
393
406
}
394
407
395
- impl AsyncRead for Connection {
408
+ impl hyper :: rt :: Read for Connection {
396
409
fn poll_read (
397
410
self : Pin < & mut Self > ,
398
411
cx : & mut Context < ' _ > ,
399
- buf : & mut ReadBuf < ' _ > ,
400
- ) -> Poll < std:: io:: Result < ( ) > > {
401
- let this = self . project ( ) ;
402
- this. inner . poll_read ( cx, buf)
412
+ mut buf : ReadBufCursor < ' _ > ,
413
+ ) -> Poll < std:: result:: Result < ( ) , std:: io:: Error > > {
414
+ let n = unsafe {
415
+ let mut tbuf = tokio:: io:: ReadBuf :: uninit ( buf. as_mut ( ) ) ;
416
+ match tokio:: io:: AsyncRead :: poll_read ( self . project ( ) . inner , cx, & mut tbuf) {
417
+ Poll :: Ready ( Ok ( ( ) ) ) => tbuf. filled ( ) . len ( ) ,
418
+ other => return other,
419
+ }
420
+ } ;
421
+
422
+ unsafe {
423
+ buf. advance ( n) ;
424
+ }
425
+ Poll :: Ready ( Ok ( ( ) ) )
403
426
}
404
427
}
405
428
406
- impl AsyncWrite for Connection {
429
+ impl hyper :: rt :: Write for Connection {
407
430
fn poll_write (
408
431
self : Pin < & mut Self > ,
409
432
cx : & mut Context < ' _ > ,
410
433
buf : & [ u8 ] ,
411
434
) -> Poll < std:: result:: Result < usize , std:: io:: Error > > {
412
- let this = self . project ( ) ;
413
- this. inner . poll_write ( cx, buf)
435
+ tokio:: io:: AsyncWrite :: poll_write ( self . project ( ) . inner , cx, buf)
414
436
}
415
437
416
438
fn poll_flush (
417
439
self : Pin < & mut Self > ,
418
440
cx : & mut Context < ' _ > ,
419
441
) -> Poll < std:: result:: Result < ( ) , std:: io:: Error > > {
420
- let this = self . project ( ) ;
421
- this. inner . poll_flush ( cx)
442
+ tokio:: io:: AsyncWrite :: poll_flush ( self . project ( ) . inner , cx)
422
443
}
423
444
424
445
fn poll_shutdown (
425
446
self : Pin < & mut Self > ,
426
447
cx : & mut Context < ' _ > ,
427
448
) -> Poll < std:: result:: Result < ( ) , std:: io:: Error > > {
428
- let this = self . project ( ) ;
429
- this. inner . poll_shutdown ( cx)
449
+ tokio:: io:: AsyncWrite :: poll_shutdown ( self . project ( ) . inner , cx)
430
450
}
431
451
}
432
452
433
- impl hyper :: client:: connect:: Connection for Connection {
434
- fn connected ( & self ) -> hyper :: client:: connect:: Connected {
435
- hyper :: client:: connect:: Connected :: new ( )
453
+ impl hyper_util :: client:: legacy :: connect:: Connection for Connection {
454
+ fn connected ( & self ) -> hyper_util :: client:: legacy :: connect:: Connected {
455
+ hyper_util :: client:: legacy :: connect:: Connected :: new ( )
436
456
}
437
457
}
438
458
0 commit comments