1
+ //! UDP handling
2
+ use kernel:: lib:: ring_buffer:: RingBuf ;
3
+ use kernel:: sync:: { RwLock , Mutex } ;
4
+ use kernel:: lib:: mem:: Arc ;
5
+ use kernel:: vec:: Vec ;
6
+ use crate :: Address ;
7
+ use crate :: nic:: SparsePacket ;
8
+
9
+ const IPV4_PROTO_UDP : u8 = 17 ;
10
+ /// Opened sockets
11
+ static SOCKETS : RwLock < Vec < Arc < SocketInfo > > > = RwLock :: new ( Vec :: new ( ) ) ;
12
+
13
+
14
+ pub fn init ( ) {
15
+ crate :: ipv4:: register_handler ( IPV4_PROTO_UDP , |int, src_addr, pkt|{
16
+ rx_handler ( Address :: Ipv4 ( src_addr) , Address :: Ipv4 ( int. addr ( ) ) , pkt)
17
+ } ) . unwrap ( ) ;
18
+ }
19
+ fn rx_handler ( src_addr : Address , dst_addr : Address , mut pkt : crate :: nic:: PacketReader )
20
+ {
21
+ let hdr = match PktHeader :: read ( & mut pkt)
22
+ {
23
+ Ok ( v) => v,
24
+ Err ( _) => {
25
+ log_error ! ( "Undersized packet: Ran out of data reading header" ) ;
26
+ return ;
27
+ } ,
28
+ } ;
29
+ // Check checksum.
30
+ let cksum = calc_checksum (
31
+ & hdr. encode ( ) ,
32
+ & src_addr, & dst_addr, pkt. remain ( ) ,
33
+ { let mut p = pkt. clone ( ) ; :: core:: iter:: from_fn ( move || p. read_u8 ( ) . ok ( ) ) }
34
+ ) ;
35
+ if cksum != 0 {
36
+ }
37
+ let pkt_data = pkt. clone ( ) ;
38
+ for sock in SOCKETS . read ( ) . iter ( ) {
39
+ if sock. key . local_port != hdr. dest_port {
40
+ continue ;
41
+ }
42
+ match sock. key . local_address {
43
+ Some ( a) if a != dst_addr => continue ,
44
+ _ => { } ,
45
+ }
46
+ if sock. key . remote_mask . 0 != src_addr. mask_network ( sock. key . remote_mask . 1 ) {
47
+ continue ;
48
+ }
49
+ match sock. key . remote_port {
50
+ Some ( a) if a != hdr. source_port => continue ,
51
+ _ => { } ,
52
+ }
53
+ // Matches!
54
+ sock. rx_buffer . push_packet ( dst_addr, src_addr, hdr. source_port , pkt_data. clone ( ) ) ;
55
+ }
56
+ }
57
+
58
+ pub enum Error {
59
+ /// Cannot create a new socket, the chosen local address is in use
60
+ AddressInUse ,
61
+ /// Trying to call `send` without a concrete destination
62
+ UnboundSocket ,
63
+ /// Sending to a remote address that isn't in the recieve mask
64
+ InvalidRemote ,
65
+ /// Trying to send to an address of a different type to the existing bound address
66
+ IncompatibleAddresses ,
67
+ }
68
+
69
+ /// Exposed handle to a registered socket
70
+ pub struct SocketHandle {
71
+ inner : Arc < SocketInfo > ,
72
+ }
73
+ impl SocketHandle {
74
+ pub fn new (
75
+ local_address : Option < crate :: Address > ,
76
+ local_port : u16 ,
77
+ remote_mask : ( crate :: Address , u8 ) ,
78
+ remote_port : Option < u16 > ,
79
+ ) -> Result < Self , Error > {
80
+ if local_port == 0 {
81
+ todo ! ( "Allocate a local port" ) ;
82
+ }
83
+ let key = SocketKey {
84
+ local_address,
85
+ local_port,
86
+ remote_mask,
87
+ remote_port,
88
+ } ;
89
+ // Check for an overlapping socket
90
+ let mut lh = SOCKETS . write ( ) ;
91
+ for sock in lh. iter ( ) {
92
+ if sock. key . overlaps_with ( & key) {
93
+ return Err ( Error :: AddressInUse ) ;
94
+ }
95
+ }
96
+ let rv = Arc :: new ( SocketInfo {
97
+ key,
98
+ rx_buffer : Default :: default ( ) ,
99
+ } ) ;
100
+ lh. push ( rv. clone ( ) ) ;
101
+ Ok ( SocketHandle { inner : rv } )
102
+ }
103
+ pub fn try_recv_from ( & mut self , buf : & mut [ u8 ] ) -> Option < ( usize , Address , u16 ) > {
104
+ match self . inner . rx_buffer . pop_packet ( buf) {
105
+ None => None ,
106
+ Some ( ( _dst, src, port, len) ) => {
107
+ Some ( ( len, src, port) )
108
+ }
109
+ }
110
+ }
111
+ /// Send a datagram to the single target address
112
+ pub fn send ( & self , buf : SparsePacket ) -> Result < ( ) , Error > {
113
+ match ( self . inner . key . remote_mask , self . inner . key . remote_port ) {
114
+ ( ( addr, bits) , Some ( port) ) if addr. mask_network ( bits) == addr => {
115
+ self . send_to ( addr, port, buf)
116
+ }
117
+ _ => Err ( Error :: UnboundSocket )
118
+ }
119
+ }
120
+ /// Send a datagram over this socket
121
+ pub fn send_to ( & self , addr : Address , port : u16 , buf : SparsePacket ) -> Result < ( ) , Error > {
122
+ // Check if the target address matches the remote mask
123
+ // TODO: Is this actually needed/right?
124
+ let ( d_addr, bits) = self . inner . key . remote_mask ;
125
+ if addr. mask_network ( bits) != d_addr. mask_network ( bits) {
126
+ return Err ( Error :: InvalidRemote ) ;
127
+ }
128
+ match self . inner . key . remote_port {
129
+ None => { } ,
130
+ Some ( v) if v == port => { } ,
131
+ _ => return Err ( Error :: InvalidRemote ) ,
132
+ }
133
+
134
+ // Create header
135
+ let mut hdr = PktHeader {
136
+ source_port : self . inner . key . local_port ,
137
+ dest_port : port,
138
+ length : buf. total_len ( ) as u16 ,
139
+ checksum : 0 ,
140
+ } ;
141
+ let local_addr = match self . inner . key . local_address {
142
+ Some ( a) => a,
143
+ None => todo ! ( ) ,
144
+ } ;
145
+ // - incl. checksum (if no offload)
146
+ if true {
147
+ hdr. checksum = calc_checksum (
148
+ & hdr. encode ( ) ,
149
+ & local_addr, & addr, buf. total_len ( ) ,
150
+ buf. into_iter ( ) . map ( |v| v. iter ( ) . copied ( ) ) . flatten ( )
151
+ ) ;
152
+ }
153
+ let hdr_enc = hdr. encode ( ) ;
154
+ let pkt = SparsePacket :: new_chained ( & hdr_enc, & buf) ;
155
+ // Send
156
+ match addr {
157
+ Address :: Ipv4 ( dest) => {
158
+ let Address :: Ipv4 ( source) = local_addr else { return Err ( Error :: IncompatibleAddresses ) } ;
159
+ kernel:: futures:: block_on ( crate :: ipv4:: send_packet ( source, dest, IPV4_PROTO_UDP , pkt) ) ;
160
+ }
161
+ }
162
+ Ok ( ( ) )
163
+ }
164
+ }
165
+ impl :: core:: ops:: Drop for SocketHandle {
166
+ fn drop ( & mut self ) {
167
+ let mut lh = SOCKETS . write ( ) ;
168
+ lh. retain ( |v| !Arc :: ptr_eq ( v, & self . inner ) ) ;
169
+ }
170
+ }
171
+ /// Underlying information on an open/listening socket
172
+ struct SocketInfo {
173
+ key : SocketKey ,
174
+ rx_buffer : MessagePool ,
175
+ }
176
+ /// The key part of a socket
177
+ struct SocketKey {
178
+ local_address : Option < crate :: Address > ,
179
+ local_port : u16 ,
180
+ remote_mask : ( crate :: Address , u8 ) ,
181
+ remote_port : Option < u16 > ,
182
+ }
183
+ impl SocketKey {
184
+ fn overlaps_with ( & self , other : & SocketKey ) -> bool {
185
+ // Local port: if the local port is different, then this cannot overlap
186
+ if self . local_port != other. local_port {
187
+ return false ;
188
+ }
189
+ // Remote port: If both are `Some` but different, then no overlap - otherwise possible.
190
+ match ( self . remote_port , other. remote_port ) {
191
+ ( Some ( a) , Some ( b) ) if a != b => return false ,
192
+ _ => { } ,
193
+ }
194
+ // Local address: Same as remote port... but is overlap allowed?
195
+ match ( self . local_address , other. local_address ) {
196
+ ( Some ( a) , Some ( b) ) if a != b => return false ,
197
+ _ => { } ,
198
+ }
199
+ // Remote: Check if the spans covered by the mask are disjoint.
200
+ let min_bits = u8:: min ( self . remote_mask . 1 , other. remote_mask . 1 ) ;
201
+ self . remote_mask . 0 . mask_network ( min_bits) == other. remote_mask . 0 . mask_network ( min_bits)
202
+ }
203
+ }
204
+ /// A pool of messages, stored as u16 length-delimited data in a ring-buf
205
+ struct MessagePool {
206
+ inner : Mutex < RingBuf < u8 > > ,
207
+ }
208
+ impl Default for MessagePool {
209
+ fn default ( ) -> Self {
210
+ Self { inner : Mutex :: new ( RingBuf :: new ( 1024 * 32 ) ) }
211
+ }
212
+ }
213
+ impl MessagePool {
214
+ fn push_packet ( & self , dest_addr : Address , src_addr : Address , src_port : u16 , mut pkt : crate :: nic:: PacketReader ) {
215
+ // Header:
216
+ // - port: u16
217
+ // - pkt_len: u16
218
+ // - addr_ty: u8
219
+ // - _pad: u8
220
+ // - address data
221
+ fn make_hdr < ' a > ( dst : & ' a mut [ u8 ] , len : usize , src_port : u16 , aty : u8 , dest_addr : & [ u8 ] , src_addr : & [ u8 ] ) -> & ' a [ u8 ] {
222
+ dst[ 0 ..] [ ..2 ] . copy_from_slice ( & ( len as u16 ) . to_ne_bytes ( ) ) ;
223
+ dst[ 2 ..] [ ..2 ] . copy_from_slice ( & src_port. to_ne_bytes ( ) ) ;
224
+ dst[ 4 ..] [ ..1 ] . copy_from_slice ( & aty. to_ne_bytes ( ) ) ;
225
+ dst[ 5 ..] [ ..1 ] . copy_from_slice ( & 0u8 . to_ne_bytes ( ) ) ;
226
+ dst[ 6 ..] [ ..dest_addr. len ( ) ] . copy_from_slice ( dest_addr) ;
227
+ dst[ 6 +dest_addr. len ( ) ..] [ ..src_addr. len ( ) ] . copy_from_slice ( src_addr) ;
228
+ & dst[ ..6 +dest_addr. len ( ) +src_addr. len ( ) ]
229
+ }
230
+ let mut hdr_buf = [ 0 ; 3 * 2 + 16 * 2 ] ;
231
+ let len = pkt. remain ( ) ;
232
+ let hdr = match ( dest_addr, src_addr) {
233
+ ( Address :: Ipv4 ( dest_addr) , Address :: Ipv4 ( src_addr) ) => {
234
+ make_hdr ( & mut hdr_buf, len, src_port, 0 , & dest_addr. 0 , & src_addr. 0 )
235
+ }
236
+ } ;
237
+
238
+ let mut lh = self . inner . lock ( ) ;
239
+ if lh. space ( ) < hdr. len ( ) + pkt. remain ( ) {
240
+ }
241
+ else {
242
+ for & b in hdr {
243
+ let _ = lh. push_back ( b) ;
244
+ }
245
+ while let Ok ( b) = pkt. read_u8 ( ) {
246
+ let _ = lh. push_back ( b) ;
247
+ }
248
+ }
249
+ //self.inner.push(val)
250
+ }
251
+ fn pop_packet ( & self , buf : & mut [ u8 ] ) -> Option < ( Address , Address , u16 , usize ) > {
252
+ let mut lh = self . inner . lock ( ) ;
253
+ if lh. len ( ) == 0 {
254
+ None
255
+ }
256
+ else {
257
+ assert ! ( lh. len( ) > 6 ) ;
258
+ // Get common header
259
+ let len = u16:: from_ne_bytes ( [ lh. pop_front ( ) . unwrap ( ) , lh. pop_front ( ) . unwrap ( ) ] ) as usize ;
260
+ let port = u16:: from_ne_bytes ( [ lh. pop_front ( ) . unwrap ( ) , lh. pop_front ( ) . unwrap ( ) ] ) ;
261
+ let aty = lh. pop_front ( ) . unwrap ( ) ;
262
+ let _pad = lh. pop_front ( ) . unwrap ( ) ;
263
+ // Get addresses
264
+ let ( dst, src) = match aty {
265
+ 0 => { // IPv4
266
+ let da = [ lh. pop_front ( ) . unwrap ( ) , lh. pop_front ( ) . unwrap ( ) , lh. pop_front ( ) . unwrap ( ) , lh. pop_front ( ) . unwrap ( ) , ] ;
267
+ let sa = [ lh. pop_front ( ) . unwrap ( ) , lh. pop_front ( ) . unwrap ( ) , lh. pop_front ( ) . unwrap ( ) , lh. pop_front ( ) . unwrap ( ) , ] ;
268
+ (
269
+ Address :: Ipv4 ( crate :: ipv4:: Address ( da) ) ,
270
+ Address :: Ipv4 ( crate :: ipv4:: Address ( sa) ) ,
271
+ )
272
+ }
273
+ _ => panic ! ( "Unknown address type in packet queue" ) ,
274
+ } ;
275
+ // Get data
276
+ assert ! ( lh. len( ) > len) ;
277
+ for dst in buf. iter_mut ( ) . take ( len) {
278
+ * dst = lh. pop_front ( ) . unwrap ( ) ;
279
+ }
280
+ Some ( ( dst, src, port, len) )
281
+ }
282
+ }
283
+ }
284
+
285
+
286
+ #[ derive( Debug ) ]
287
+ struct PktHeader
288
+ {
289
+ source_port : u16 ,
290
+ dest_port : u16 ,
291
+ length : u16 ,
292
+ checksum : u16 ,
293
+ }
294
+ impl PktHeader
295
+ {
296
+ fn read ( reader : & mut crate :: nic:: PacketReader ) -> Result < Self , ( ) >
297
+ {
298
+ Ok ( PktHeader {
299
+ source_port : reader. read_u16n ( ) ?,
300
+ dest_port : reader. read_u16n ( ) ?,
301
+ length : reader. read_u16n ( ) ?,
302
+ checksum : reader. read_u16n ( ) ?,
303
+ } )
304
+ }
305
+ fn encode ( & self ) -> [ u8 ; 8 ] {
306
+ // SAFE:
307
+ unsafe { :: core:: mem:: transmute ( [
308
+ self . source_port . to_le_bytes ( ) ,
309
+ self . dest_port . to_le_bytes ( ) ,
310
+ self . length . to_le_bytes ( ) ,
311
+ self . checksum . to_le_bytes ( ) ,
312
+ ] ) }
313
+ }
314
+ }
315
+ fn calc_checksum ( hdr : & [ u8 ] , src_addr : & Address , dst_addr : & Address , data_len : usize , data : impl Iterator < Item =u8 > ) -> u16 {
316
+ let pkt_len = ( ( hdr. len ( ) + data_len) as u16 ) . to_be_bytes ( ) ;
317
+ match src_addr {
318
+ Address :: Ipv4 ( src_addr) => {
319
+ let Address :: Ipv4 ( dst_addr) = dst_addr else { panic ! ( "Mismatched address types" ) } ;
320
+ let ph = [
321
+ src_addr. 0 [ 0 ] , src_addr. 0 [ 1 ] , src_addr. 0 [ 2 ] , src_addr. 0 [ 3 ] ,
322
+ dst_addr. 0 [ 0 ] , dst_addr. 0 [ 1 ] , dst_addr. 0 [ 2 ] , dst_addr. 0 [ 3 ] ,
323
+ 0 , IPV4_PROTO_UDP ,
324
+ pkt_len[ 0 ] , pkt_len[ 1 ] ,
325
+ ] ;
326
+ calc_checksum_inner ( hdr, & ph, data)
327
+ } ,
328
+ }
329
+ }
330
+ fn calc_checksum_inner ( hdr : & [ u8 ] , ph : & [ u8 ] , data : impl Iterator < Item =u8 > ) -> u16
331
+ {
332
+ return super :: ipv4:: calculate_checksum ( Words (
333
+ Iterator :: chain ( ph. iter ( ) . copied ( ) , hdr. iter ( ) . copied ( ) )
334
+ . chain ( data )
335
+ ) ) ;
336
+ struct Words < I > ( I ) ;
337
+ impl < I > Iterator for Words < I >
338
+ where I : Iterator < Item =u8 >
339
+ {
340
+ type Item = u16 ;
341
+
342
+ fn next ( & mut self ) -> Option < Self :: Item > {
343
+ // NOTE: This only really works on fused iterators
344
+ match ( self . 0 . next ( ) , self . 0 . next ( ) ) {
345
+ ( Some ( a) , Some ( b) ) => Some ( u16:: from_be_bytes ( [ a, b] ) ) ,
346
+ ( Some ( a) , None ) => Some ( u16:: from_be_bytes ( [ a, 0 ] ) ) ,
347
+ ( None , _) => None ,
348
+ }
349
+ }
350
+ }
351
+ }
0 commit comments