Skip to content

Commit 5cb0092

Browse files
committed
TCP - Fix checksum calculation
1 parent f0275d4 commit 5cb0092

File tree

3 files changed

+113
-80
lines changed

3 files changed

+113
-80
lines changed

Kernel/Modules/network/ipv4.rs

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ use crate::nic::MacAddr;
1010
mod address;
1111
pub use self::address::Address;
1212

13+
pub mod checksum;
14+
pub use self::checksum::from_words as calculate_checksum;
15+
1316
mod headers;
1417
use self::headers::Ipv4Header;
1518

@@ -83,21 +86,6 @@ pub fn listen_raw(local_addr: Address, proto: u8, remote_mask: (Address, u8)) ->
8386
{
8487
}
8588

86-
// Calculate a checksum of a sequence of NATIVE ENDIAN (not network) 16-bit words
87-
pub fn calculate_checksum(words: impl Iterator<Item=u16>) -> u16
88-
{
89-
let mut sum = 0;
90-
for v in words
91-
{
92-
sum += v as usize;
93-
}
94-
while sum > 0xFFFF
95-
{
96-
sum = (sum & 0xFFFF) + (sum >> 16);
97-
}
98-
!sum as u16
99-
}
100-
10189
/// Send a raw packet
10290
pub async fn send_packet(source: Address, dest: Address, proto: u8, pkt: crate::nic::SparsePacket<'_>) -> Result<(),()>
10391
{
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
2+
// Calculate a checksum of a sequence of NATIVE ENDIAN (not network) 16-bit words
3+
pub fn from_words(words: impl Iterator<Item=u16>) -> u16
4+
{
5+
let mut sum = 0;
6+
for v in words
7+
{
8+
sum += v as usize;
9+
}
10+
while sum > 0xFFFF
11+
{
12+
sum = (sum & 0xFFFF) + (sum >> 16);
13+
}
14+
!sum as u16
15+
}
16+
17+
pub fn from_bytes(bytes: impl Iterator<Item=u8>) -> u16
18+
{
19+
struct Words<I>(I);
20+
impl<I> Iterator for Words<I>
21+
where I: Iterator<Item=u8>
22+
{
23+
type Item = u16;
24+
25+
fn next(&mut self) -> Option<Self::Item> {
26+
// NOTE: This only really works on fused iterators
27+
match (self.0.next(),self.0.next()) {
28+
(Some(a),Some(b)) => Some(u16::from_be_bytes([a,b])),
29+
(Some(a),None) => Some(u16::from_be_bytes([a,0])),
30+
(None,_) => None,
31+
}
32+
}
33+
}
34+
35+
from_words(Words(bytes.fuse()))
36+
}
37+
38+
pub fn from_reader(mut reader: crate::nic::PacketReader) -> u16
39+
{
40+
let len = reader.remain();
41+
from_bytes( (0 .. len).map(|_| reader.read_u8().unwrap()) )
42+
}

Kernel/Modules/network/tcp.rs

Lines changed: 68 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -119,47 +119,18 @@ fn rx_handler(src_addr: Address, dest_addr: Address, mut pkt: crate::nic::Packet
119119
}
120120

121121
// Validate checksum.
122+
let checksum = calculate_checksum(src_addr, dest_addr, &hdr, pkt.remain(),
122123
{
123-
use crate::ipv4::calculate_checksum;
124-
125-
let packet_len = pre_header_reader.remain();
126-
// Pseudo header for checksum
127-
let sum_pseudo = match src_addr
128-
{
129-
Address::Ipv4(s) => {
130-
let Address::Ipv4(d) = dest_addr else { unreachable!() };
131-
calculate_checksum([
132-
// Big endian stores MSB first, so write the high word first
133-
(s.as_u32() >> 16) as u16, (s.as_u32() >> 0) as u16,
134-
(d.as_u32() >> 16) as u16, (d.as_u32() >> 0) as u16,
135-
IPV4_PROTO_TCP as u16, packet_len as u16,
136-
].iter().copied())
137-
},
138-
Address::Ipv6(s) => {
139-
let Address::Ipv6(d) = dest_addr else { unreachable!() };
140-
calculate_checksum([
141-
// Big endian stores MSB first, so write the high word first
142-
s.words()[0], s.words()[1], s.words()[2], s.words()[3],
143-
s.words()[4], s.words()[5], s.words()[6], s.words()[7],
144-
d.words()[0], d.words()[1], d.words()[2], d.words()[3],
145-
d.words()[4], d.words()[5], d.words()[6], d.words()[7],
146-
IPV4_PROTO_TCP as u16, packet_len as u16,
147-
].iter().copied())
148-
}
149-
};
150-
let sum_header = hdr.checksum();
151-
let sum_options_and_data = {
152-
let mut pkt = pkt.clone();
153-
let psum_whole = !calculate_checksum( (0 .. (pre_header_reader.remain() - hdr_len) / 2).map(|_| pkt.read_u16n().unwrap()) );
154-
// Final byte is decoded as if there was a zero after it (so as 0x??00)
155-
let psum_partial = if pkt.remain() > 0 { (pkt.read_u8().unwrap() as u16) << 8} else { 0 };
156-
calculate_checksum([psum_whole, psum_partial].iter().copied())
157-
};
158-
let sum_total = calculate_checksum([ !sum_pseudo, !sum_header, !sum_options_and_data ].iter().copied());
159-
if sum_total != 0 {
160-
log_error!("Incorrect checksum: 0x{:04x} != 0", sum_total);
161-
// TODO: Discard the packet.
124+
let mut pkt = pkt.clone();
125+
let psum_whole = !crate::ipv4::calculate_checksum( (0 .. pkt.remain() / 2).map(|_| pkt.read_u16n().unwrap()) );
126+
// Final byte is decoded as if there was a zero after it (so as 0x??00)
127+
let psum_partial = if pkt.remain() > 0 { (pkt.read_u8().unwrap() as u16) << 8} else { 0 };
128+
crate::ipv4::calculate_checksum([psum_whole, psum_partial].iter().copied())
162129
}
130+
);
131+
if checksum != 0 {
132+
log_error!("Incorrect checksum: 0x{:04x} != 0", checksum);
133+
// TODO: Discard the packet.
163134
}
164135

165136
// Options
@@ -252,6 +223,40 @@ fn rx_handler(src_addr: Address, dest_addr: Address, mut pkt: crate::nic::Packet
252223
// Otherwise, drop
253224
}
254225

226+
fn calculate_checksum(src_addr: Address, dest_addr: Address, hdr: &PktHeader, tail_len: usize, tail_sum: u16) -> u16
227+
{
228+
use crate::ipv4::calculate_checksum as ip_checksum;
229+
230+
let packet_len = (5*4) + tail_len;
231+
232+
// Pseudo header for checksum
233+
let sum_pseudo = match src_addr
234+
{
235+
Address::Ipv4(s) => {
236+
let Address::Ipv4(d) = dest_addr else { unreachable!() };
237+
ip_checksum([
238+
// Big endian stores MSB first, so write the high word first
239+
(s.as_u32() >> 16) as u16, (s.as_u32() >> 0) as u16,
240+
(d.as_u32() >> 16) as u16, (d.as_u32() >> 0) as u16,
241+
IPV4_PROTO_TCP as u16, packet_len as u16,
242+
].iter().copied())
243+
},
244+
Address::Ipv6(s) => {
245+
let Address::Ipv6(d) = dest_addr else { unreachable!() };
246+
ip_checksum([
247+
// Big endian stores MSB first, so write the high word first
248+
s.words()[0], s.words()[1], s.words()[2], s.words()[3],
249+
s.words()[4], s.words()[5], s.words()[6], s.words()[7],
250+
d.words()[0], d.words()[1], d.words()[2], d.words()[3],
251+
d.words()[4], d.words()[5], d.words()[6], d.words()[7],
252+
IPV4_PROTO_TCP as u16, packet_len as u16,
253+
].iter().copied())
254+
}
255+
};
256+
let sum_header = hdr.checksum();
257+
ip_checksum([ !sum_pseudo, !sum_header, !tail_sum ].iter().copied())
258+
}
259+
255260
#[derive(Copy,Clone,PartialEq,PartialOrd,Eq,Ord)]
256261
struct ListenPair(Option<Address>, u16);
257262
impl ::core::fmt::Debug for ListenPair
@@ -305,17 +310,29 @@ impl Quad
305310
// TODO: Any options required?
306311
let options_bytes = &[];
307312
let opts_len_rounded = ((options_bytes.len() + 3) / 4) * 4;
308-
let hdr = PktHeader {
309-
source_port: self.local_port,
310-
dest_port: self.remote_port,
311-
sequence_number: seq,
312-
acknowledgement_number: ack,
313-
data_offset: ((5 + opts_len_rounded/4) << 4) as u8 | 0,
314-
flags: flags,
315-
window_size: window_size,
316-
checksum: 0, // To be filled afterwards
317-
urgent_pointer: 0,
318-
}.as_bytes();
313+
let hdr = {
314+
let mut hdr = PktHeader {
315+
source_port: self.local_port,
316+
dest_port: self.remote_port,
317+
sequence_number: seq,
318+
acknowledgement_number: ack,
319+
data_offset: ((5 + opts_len_rounded/4) << 4) as u8 | 0,
320+
flags,
321+
window_size,
322+
checksum: 0, // To be filled afterwards
323+
urgent_pointer: 0,
324+
};
325+
hdr.checksum = calculate_checksum(
326+
self.local_addr, self.remote_addr,
327+
&hdr,
328+
data1.len()+data2.len(),
329+
super::ipv4::checksum::from_bytes(
330+
// TODO: options (padded to multiple of 4 bytes)
331+
Iterator::chain(data1.iter().copied(), data2.iter().copied() )
332+
)
333+
);
334+
hdr.as_bytes()
335+
};
319336
// Calculate checksum
320337

321338
// Create sparse packet chain
@@ -418,22 +435,8 @@ impl PktHeader
418435
(self.urgent_pointer >> 0) as u8,
419436
]
420437
}
421-
fn as_u16s(&self) -> [u16; 5*2] {
422-
[
423-
self.source_port,
424-
self.dest_port,
425-
(self.sequence_number >> 16) as u16,
426-
(self.sequence_number >> 0) as u16,
427-
(self.acknowledgement_number >> 16) as u16,
428-
(self.acknowledgement_number >> 0) as u16,
429-
(self.data_offset as u16) << 8 | (self.flags as u16),
430-
self.window_size,
431-
self.checksum,
432-
self.urgent_pointer,
433-
]
434-
}
435438
fn checksum(&self) -> u16 {
436-
crate::ipv4::calculate_checksum(self.as_u16s().iter().cloned())
439+
crate::ipv4::checksum::from_bytes(self.as_bytes().iter().copied())
437440
}
438441
}
439442

0 commit comments

Comments
 (0)