Skip to content

Commit a9698f7

Browse files
authored
fix: divide 64 in packages (#4)
Signed-off-by: Keming <[email protected]>
1 parent 1f1987a commit a9698f7

File tree

6 files changed

+39
-31
lines changed

6 files changed

+39
-31
lines changed

Cargo.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
members = ["crates/*"]
33

44
[workspace.package]
5-
version = "0.2.1"
5+
version = "0.2.2"
66
edition = "2021"
77
description = "A Rust implementation of the RaBitQ vector search algorithm."
88
license = "AGPL-3.0"

Makefile

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1+
packages := cli disk service
2+
13
build:
24
cargo b
35

46
format:
57
@cargo +nightly fmt
8+
@$(foreach package, $(packages), cargo +nightly fmt --package $(package);)
69

710
lint:
8-
@cargo +nightly fmt -- --check
11+
@cargo +nightly fmt --check
12+
@$(foreach package, $(packages), cargo +nightly fmt --package $(package) --check;)
913
@cargo clippy -- -D warnings
14+
@$(foreach package, $(packages), cargo clippy --package $(package) -- -D warnings;)
1015

1116
test:
1217
@cargo test

crates/disk/src/cache.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ impl CachedVector {
7979
let s3_client = Arc::new(Client::new(&s3_config));
8080
let num_per_block = BLOCK_BYTE_LIMIT / (4 * (dim + 1));
8181
let total_num = num;
82-
let total_block = (total_num + num_per_block - 1) / num_per_block;
82+
let total_block = total_num.div_ceil(num_per_block);
8383
let sqlite_conn = Connection::open(Path::new(&local_path)).expect("failed to open sqlite");
8484
sqlite_conn
8585
.execute(

crates/disk/src/disk.rs

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@ use crate::cache::CachedVector;
1818

1919
/// Rank with cached raw vectors.
2020
#[derive(Debug)]
21-
pub struct CacheReRanker {
21+
pub struct CacheReRanker<'a> {
2222
threshold: f32,
2323
topk: usize,
2424
heap: BinaryHeap<(Ord32, AlwaysEqual<u32>)>,
25-
query: Vec<f32>,
25+
query: &'a [f32],
2626
}
2727

28-
impl CacheReRanker {
29-
fn new(query: Vec<f32>, topk: usize) -> Self {
28+
impl<'a> CacheReRanker<'a> {
29+
fn new(query: &'a [f32], topk: usize) -> Self {
3030
Self {
3131
threshold: f32::MAX,
3232
query,
@@ -45,7 +45,7 @@ impl CacheReRanker {
4545
for &(rough, u) in rough_distances.iter() {
4646
if rough < self.threshold {
4747
let accurate = cache
48-
.get_query_vec_distance(&self.query, u)
48+
.get_query_vec_distance(self.query, u)
4949
.await
5050
.expect("failed to get distance");
5151
precise += 1;
@@ -142,11 +142,16 @@ impl DiskRaBitQ {
142142

143143
/// Query the topk nearest neighbors for the given query asynchronously.
144144
pub async fn query(&self, query: Vec<f32>, probe: usize, topk: usize) -> Vec<(f32, u32)> {
145-
assert_eq!(self.dim as usize, query.len());
146-
let y_projected = project(&query, &self.orthogonal.as_ref());
145+
assert_eq!(self.dim as usize, query.len().div_ceil(64) * 64);
146+
// padding
147+
let mut query_vec = query.to_vec();
148+
if query.len() < self.dim as usize {
149+
query_vec.extend_from_slice(&vec![0.0; self.dim as usize - query.len()]);
150+
}
151+
152+
let y_projected = project(&query_vec, &self.orthogonal.as_ref());
147153
let k = self.centroids.shape().1;
148154
let mut lists = Vec::with_capacity(k);
149-
let mut residual = vec![0f32; self.dim as usize];
150155
for (i, centroid) in self.centroids.col_iter().enumerate() {
151156
let dist = l2_squared_distance(
152157
centroid
@@ -161,10 +166,11 @@ impl DiskRaBitQ {
161166
lists.truncate(length);
162167
lists.sort_by(|a, b| a.0.total_cmp(&b.0));
163168

164-
let mut re_ranker = CacheReRanker::new(query, topk);
169+
let mut re_ranker = CacheReRanker::new(&query_vec, topk);
170+
let mut residual = vec![0f32; self.dim as usize];
171+
let mut quantized = vec![0u8; (self.dim as usize).div_ceil(64) * 64];
165172
let mut rough_distances = Vec::new();
166-
let mut quantized = vec![0u8; self.dim as usize];
167-
let mut binary_vec = vec![0u64; self.dim as usize * THETA_LOG_DIM as usize / 64];
173+
let mut binary_vec = vec![0u64; self.dim.div_ceil(64) as usize * THETA_LOG_DIM as usize];
168174
for &(dist, i) in lists[..length].iter() {
169175
let (lower_bound, upper_bound) =
170176
min_max_residual(&mut residual, &y_projected.as_ref(), &self.centroids.col(i));

crates/service/src/main.rs

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,15 @@ mod args;
1818
async fn shutdown_signal() {
1919
let mut interrupt = signal(SignalKind::interrupt()).unwrap();
2020
let mut terminate = signal(SignalKind::terminate()).unwrap();
21-
loop {
22-
tokio::select! {
23-
_ = interrupt.recv() => {
24-
info!("Received interrupt signal");
25-
break;
26-
}
27-
_ = terminate.recv() => {
28-
info!("Received terminate signal");
29-
break;
30-
}
31-
};
32-
}
21+
tokio::select! {
22+
_ = interrupt.recv() => {
23+
info!("Received interrupt signal");
24+
}
25+
_ = terminate.recv() => {
26+
info!("Received terminate signal");
27+
}
28+
};
29+
info!("Shutting down");
3330
}
3431

3532
async fn health_check() -> impl IntoResponse {
@@ -75,7 +72,7 @@ async fn main() {
7572

7673
let config: args::Args = argh::from_env();
7774
let model_path = Path::new(&config.dir);
78-
download_meta_from_s3(&config.bucket, &config.key, &model_path)
75+
download_meta_from_s3(&config.bucket, &config.key, model_path)
7976
.await
8077
.expect("failed to download meta");
8178
let rabitq =

0 commit comments

Comments
 (0)