Skip to content

Commit 3386092

Browse files
committed
feat: use tokio fully for async
1 parent 7b26a7c commit 3386092

File tree

10 files changed

+132
-89
lines changed

10 files changed

+132
-89
lines changed

Cargo.lock

Lines changed: 11 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/game-solver/Cargo.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,23 @@ edition = "2021"
1010

1111
[features]
1212
"xxhash" = ["dep:twox-hash"]
13-
"rayon" = ["dep:rayon", "xxhash", "dep:sysinfo", "dep:moka"]
13+
"rayon" = ["xxhash", "dep:sysinfo", "dep:moka", "dep:tokio", "dep:tokio-util"]
1414
"js" = ["moka/js"]
1515

1616
[dependencies]
1717
# dfdx = { git = "https://github.com/coreylowman/dfdx.git", rev = "4722a99", optional = true }
1818
moka = { version = "0.12", optional = true, features = ["future"] }
1919
rand = { version = "0.8", optional = true }
20-
rayon = { version = "1.8", optional = true }
2120
sysinfo = { version = "0.30", optional = true }
2221
twox-hash = { version = "1.6", optional = true }
2322
itertools = { version = "0.13" }
2423
futures = "0.3.30"
2524
thiserror = "1.0"
2625
castaway = "0.2.3"
2726
fxhash = "0.2.1"
27+
smallvec = "1.13.2"
28+
tokio-util = { version = "0.7.13", optional = true }
29+
tokio = { version = "1.43.0", optional = true }
2830

2931
[package.metadata.docs.rs]
3032
all-features = true

crates/game-solver/src/lib.rs

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@ pub mod transposition;
1616

1717
use core::panic;
1818
#[cfg(feature = "rayon")]
19+
use tokio_util::sync::CancellationToken;
20+
#[cfg(feature = "rayon")]
1921
use std::hash::BuildHasher;
22+
#[cfg(feature = "rayon")]
23+
use std::sync::Arc;
2024
use std::sync::atomic::Ordering;
2125

2226
use game::{upper_bound, GameState};
@@ -32,6 +36,8 @@ use thiserror::Error;
3236
pub enum GameSolveError<T: Game> {
3337
#[error("could not make a move")]
3438
MoveError(T::MoveError),
39+
#[error("solving was cancelled")]
40+
Cancelled,
3541
}
3642

3743
/// Runs the two-player minimax variant on a zero-sum game.
@@ -304,42 +310,61 @@ pub type CollectedMoves<T> = Vec<Result<(<T as Game>::Move, isize), GameSolveErr
304310
///
305311
/// A vector of tuples of the form `(move, score)`.
306312
#[cfg(feature = "rayon")]
307-
pub fn par_move_scores_with_hasher<
308-
T: Game<Player = impl TwoPlayer + Sync + 'static> + Eq + Hash + Sync + Send + 'static,
313+
pub async fn par_move_scores_with_hasher<
314+
T: Game<Player = impl TwoPlayer + Sync + Send + 'static> + Eq + Hash + Sync + Send + 'static,
309315
S,
310316
>(
311317
game: &T,
312-
stats: Option<&Stats<T::Player>>
318+
stats: Option<Arc<Stats<T::Player>>>,
319+
cancellation_token: Option<CancellationToken>
313320
) -> CollectedMoves<T>
314321
where
315322
T::Move: Sync + Send,
316323
T::MoveError: Sync + Send,
317324
S: BuildHasher + Default + Sync + Send + Clone + 'static,
318325
{
326+
use itertools::Itertools;
327+
319328
use crate::transposition::TranspositionCache;
320-
use rayon::prelude::*;
321329
use std::sync::Arc;
322330

323-
// we need to collect it first as we cant parallelize an already non-parallel iterator
324-
let all_moves = game.possible_moves().collect::<Vec<_>>();
325-
let hashmap = Arc::new(TranspositionCache::<T, S>::new());
326-
327-
all_moves
328-
.par_iter()
329-
.map(move |m| {
331+
let result = game.possible_moves().map(|m| {
332+
let m = m.clone();
333+
let game = game.clone();
334+
let cancellation_token = cancellation_token.clone();
335+
let stats = stats.clone();
336+
337+
tokio::spawn(async move {
338+
let hashmap = Arc::new(TranspositionCache::<T, S>::new());
330339
let mut board = game.clone();
331340
board
332-
.make_move(m)
341+
.make_move(&m)
333342
.map_err(|err| GameSolveError::MoveError::<T>(err))?;
334343
// We flip the sign of the score because we want the score from the
335344
// perspective of the player pla`ying the move, not the player whose turn it is.
336345
let mut map = Arc::clone(&hashmap);
337-
Ok((
338-
(*m).clone(),
339-
-solve(&board, &mut map, stats)?,
340-
))
346+
347+
let handle = tokio::spawn(async move {
348+
solve(&board, &mut map, stats.as_deref()).map(|score| -score)
349+
});
350+
351+
if let Some(cancellation_token) = cancellation_token {
352+
tokio::select! {
353+
_ = cancellation_token.cancelled() => {
354+
Err(GameSolveError::Cancelled)
355+
},
356+
result = handle => {
357+
result.unwrap().map(|result| (m, result))
358+
}
359+
}
360+
} else {
361+
handle.await.unwrap().map(|x| (m, x))
362+
}
341363
})
342-
.collect::<Vec<_>>()
364+
})
365+
.collect::<Vec<_>>();
366+
367+
futures::future::join_all(result).await.into_iter().map(|result| result.unwrap()).collect_vec()
343368
}
344369

345370
/// Parallelized version of `move_scores`. (faster by a large margin)
@@ -353,21 +378,22 @@ where
353378
///
354379
/// A vector of tuples of the form `(move, score)`.
355380
#[cfg(feature = "rayon")]
356-
pub fn par_move_scores<
357-
T: Game<Player = impl TwoPlayer + Sync + 'static> + Eq + Hash + Sync + Send + 'static,
381+
pub async fn par_move_scores<
382+
T: Game<Player = impl TwoPlayer + Sync + Send + 'static> + Eq + Hash + Sync + Send + 'static,
358383
>(
359384
game: &T,
360-
stats: Option<&Stats<T::Player>>
385+
stats: Option<Arc<Stats<T::Player>>>,
386+
cancellation_token: Option<CancellationToken>
361387
) -> CollectedMoves<T>
362388
where
363389
T::Move: Sync + Send,
364390
T::MoveError: Sync + Send,
365391
{
366392
if cfg!(feature = "xxhash") {
367393
use twox_hash::RandomXxHashBuilder64;
368-
par_move_scores_with_hasher::<T, RandomXxHashBuilder64>(game, stats)
394+
par_move_scores_with_hasher::<T, RandomXxHashBuilder64>(game, stats, cancellation_token).await
369395
} else {
370396
use std::collections::hash_map::RandomState;
371-
par_move_scores_with_hasher::<T, RandomState>(game, stats)
397+
par_move_scores_with_hasher::<T, RandomState>(game, stats, cancellation_token).await
372398
}
373399
}

crates/game-solver/src/loopy.rs

Lines changed: 22 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::fmt::Debug;
22
use std::hash::Hash;
33
use std::marker::PhantomData;
44

5-
use fxhash::FxHashSet;
5+
use smallvec::SmallVec;
66

77
/// We handle loopy games with a custom struct, `LoopyTracker`, which is a
88
/// HashSet of some state T. This is used to keep track of the states that
@@ -12,9 +12,9 @@ use fxhash::FxHashSet;
1212
///
1313
/// We say `T` is the primary type, and `S` is some representation of `T` without the `LoopyTracker`.
1414
15-
#[derive(Debug, Clone)]
15+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1616
pub struct LoopyTracker<S: Eq + Hash, T: Eq + Hash> {
17-
visited: FxHashSet<T>,
17+
visited: SmallVec<[T; 3]>,
1818
_phantom: PhantomData<S>,
1919
}
2020

@@ -25,55 +25,47 @@ pub trait Loopy<S: Hash + Eq> where Self: Eq + Hash + Sized {
2525
fn without_tracker(&self) -> S;
2626
}
2727

28-
impl<S: Eq + Hash, T: Eq + Hash> LoopyTracker<S, T> {
28+
impl<S: Eq + Hash, T: Eq + Hash + Loopy<S>> LoopyTracker<S, T> {
2929
/// Create a new `LoopyTracker`.
3030
pub fn new() -> Self {
3131
Self {
32-
visited: FxHashSet::default(),
32+
visited: SmallVec::new(),
3333
_phantom: PhantomData,
3434
}
3535
}
3636

3737
/// Check if a state has been visited.
3838
pub fn has_visited(&self, state: &T) -> bool {
39-
self.visited.contains(state)
39+
self.visited.contains(&state)
4040
}
4141

4242
/// Mark a state as visited.
4343
pub fn mark_visited(&mut self, state: T) {
44-
self.visited.insert(state);
44+
self.visited.push(state);
4545
}
4646

4747
/// The number of states visited.
48-
pub fn age(&self) -> usize {
48+
pub fn halfmoves(&self) -> usize {
4949
self.visited.len()
5050
}
51-
}
5251

53-
impl<S: Eq + Hash, T: Eq + Hash> Default for LoopyTracker<S, T> {
54-
fn default() -> Self {
55-
Self::new()
52+
/// This should be called when an irreversible move is made,
53+
/// in place of `Self::mark_visited`.
54+
pub fn clear(&mut self) {
55+
self.visited.clear();
5656
}
5757
}
5858

59-
impl<S: Eq + Hash, T: Eq + Hash + Loopy<S>> PartialEq for LoopyTracker<S, T> {
60-
fn eq(&self, other: &Self) -> bool {
61-
for item in self.visited.iter() {
62-
if !other.visited.contains(item) {
63-
return false;
64-
}
65-
}
66-
67-
return true;
59+
impl<S: Eq + Hash, T: Eq + Hash + Loopy<S>> Default for LoopyTracker<S, T> {
60+
fn default() -> Self {
61+
Self::new()
6862
}
6963
}
7064

71-
impl<S: Eq + Hash, T: Eq + Hash + Loopy<S>> Eq for LoopyTracker<S, T> {}
72-
73-
impl<S: Eq + Hash, T: Eq + Hash + Loopy<S>> Hash for LoopyTracker<S, T> {
74-
fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
75-
for item in self.visited.iter() {
76-
item.without_tracker().hash(hasher);
77-
}
78-
}
79-
}
65+
// impl<S: Eq + Hash, T: Eq + Hash + Loopy<S>> Hash for LoopyTracker<S, T> {
66+
// fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
67+
// for item in self.visited.iter() {
68+
// item.hash(hasher);
69+
// }
70+
// }
71+
// }

crates/games-cli/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,5 @@ dialoguer = "0.11.0"
1313
clearscreen = "4.0.1"
1414
owo-colors = "4.1.0"
1515
tokio-util = "0.7.13"
16-
tokio = { version = "1.43.0", features = ["rt", "macros", "rt-multi-thread"] }
16+
tokio = { version = "1.43.0", features = ["rt", "macros", "rt-multi-thread", "signal"] }
1717
ratatui = "0.29.0"

crates/games-cli/src/human.rs

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
use std::{
2-
fmt::Display, future::IntoFuture, sync::{
2+
fmt::Display,
3+
future::IntoFuture,
4+
sync::{
35
atomic::{AtomicU64, AtomicUsize, Ordering},
46
Arc,
5-
}, time::Duration
7+
},
8+
time::Duration,
69
};
710

811
use anyhow::Result;
9-
use tokio::select;
10-
use tokio_util::sync::CancellationToken;
1112
use core::hash::Hash;
1213
use game_solver::{
1314
game::Game,
@@ -29,6 +30,8 @@ use ratatui::{
2930
DefaultTerminal, Frame,
3031
};
3132
use std::fmt::Debug;
33+
use tokio::select;
34+
use tokio_util::sync::CancellationToken;
3235

3336
use super::report::{scores::show_scores, stats::show_stats};
3437

@@ -196,12 +199,8 @@ where
196199
let internal_stats = stats.clone();
197200

198201
let game_thread = tokio::spawn(async move {
199-
let game_solving_thread = tokio::spawn(async move {
200-
par_move_scores(
201-
&internal_game,
202-
Some(internal_stats.as_ref()),
203-
)
204-
}).into_future();
202+
let game_solving_thread =
203+
par_move_scores(&internal_game, Some(internal_stats), Some(exit.clone()));
205204

206205
select! {
207206
score = game_solving_thread => {
@@ -218,7 +217,7 @@ where
218217

219218
show_stats::<T>(&stats);
220219
match move_scores {
221-
Some(move_scores) => show_scores(&game, move_scores?),
220+
Some(move_scores) => show_scores(&game, move_scores),
222221
None => eprintln!("Game solving was cancelled!"),
223222
}
224223

crates/games-cli/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ pub async fn play<
3333
match game.state() {
3434
GameState::Playable => {
3535
if plain {
36-
robotic_output(game);
36+
robotic_output(game).await;
3737
} else {
3838
human_output(game).await.unwrap();
3939
}

crates/games-cli/src/report/scores.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ where
1515
GameSolveError::MoveError(err) => {
1616
eprintln!("Error making move: {:?}", err);
1717
},
18+
GameSolveError::Cancelled => {
19+
eprintln!("Game solving was cancelled!");
20+
},
1821
}
1922
vec![]
2023
});

0 commit comments

Comments
 (0)