Skip to content

Commit 954d24a

Browse files
authored
Save config and optim history (#190)
1 parent 062cd28 commit 954d24a

File tree

2 files changed

+92
-7
lines changed

2 files changed

+92
-7
lines changed

ego/src/egor.rs

Lines changed: 91 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,15 +102,22 @@ use crate::errors::Result;
102102
use crate::gpmix::mixint::*;
103103
use crate::types::*;
104104
use crate::EgorConfig;
105+
use crate::EgorState;
105106
use crate::{to_xtypes, EgorSolver};
106107

108+
use argmin::core::observers::ObserverMode;
107109
use egobox_moe::GpMixtureParams;
108110
use log::info;
109-
use ndarray::{concatenate, ArrayBase, Axis, Data, Ix2};
111+
use ndarray::{concatenate, Array2, ArrayBase, Axis, Data, Ix2};
110112
use ndarray_rand::rand::SeedableRng;
111113
use rand_xoshiro::Xoshiro256Plus;
112114

113-
use argmin::core::{Executor, State};
115+
use argmin::core::{observers::Observe, Error, Executor, State, KV};
116+
117+
/// Json filename for configuration
118+
pub const CONFIG_FILE: &str = "egor_config.json";
119+
/// Numpy filename for optimization history
120+
pub const HISTORY_FILE: &str = "egor_history.npy";
114121

115122
/// EGO optimizer builder allowing to specify function to be minimized
116123
/// subject to constraints intended to be negative.
@@ -194,13 +201,25 @@ impl<O: GroupFunc, SB: SurrogateBuilder> Egor<O, SB> {
194201
pub fn run(&self) -> Result<OptimResult<f64>> {
195202
let xtypes = self.solver.config.xtypes.clone();
196203
info!("{:?}", self.solver.config);
204+
if let Some(outdir) = self.solver.config.outdir.as_ref() {
205+
std::fs::create_dir_all(outdir)?;
206+
let filepath = std::path::Path::new(outdir).join(CONFIG_FILE);
207+
let json = serde_json::to_string(&self.solver.config).unwrap();
208+
std::fs::write(filepath, json).expect("Unable to write file");
209+
}
197210

198-
let result = Executor::new(self.fobj.clone(), self.solver.clone()).run()?;
211+
let exec = Executor::new(self.fobj.clone(), self.solver.clone());
212+
let result = if let Some(outdir) = self.solver.config.outdir.as_ref() {
213+
let hist = OptimizationObserver::new(outdir.clone());
214+
exec.add_observer(hist, ObserverMode::Always).run()?
215+
} else {
216+
exec.run()?
217+
};
199218
info!("{}", result);
200219
let (x_data, y_data) = result.state().clone().take_data().unwrap();
201220

202221
let res = if !self.solver.config.discrete() {
203-
info!("History: \n{}", concatenate![Axis(1), x_data, y_data]);
222+
info!("Data: \n{}", concatenate![Axis(1), x_data, y_data]);
204223
OptimResult {
205224
x_opt: result.state.get_best_param().unwrap().to_owned(),
206225
y_opt: result.state.get_full_best_cost().unwrap().to_owned(),
@@ -210,7 +229,7 @@ impl<O: GroupFunc, SB: SurrogateBuilder> Egor<O, SB> {
210229
}
211230
} else {
212231
let x_data = to_discrete_space(&xtypes, &x_data.view());
213-
info!("History: \n{}", concatenate![Axis(1), x_data, y_data]);
232+
info!("Data: \n{}", concatenate![Axis(1), x_data, y_data]);
214233

215234
let x_opt = result
216235
.state
@@ -233,6 +252,73 @@ impl<O: GroupFunc, SB: SurrogateBuilder> Egor<O, SB> {
233252
}
234253
}
235254

255+
// The optimization observer collects best costs ans params
256+
// during the optimization execution allowing to get optimization history
257+
// saved as a numpy array for further analysis
258+
// Note: the observer is activated only when outdir is specified
259+
#[derive(Default)]
260+
struct OptimizationObserver {
261+
pub dir: String,
262+
pub best_params: Option<Array2<f64>>,
263+
pub best_costs: Option<Array2<f64>>,
264+
}
265+
266+
impl OptimizationObserver {
267+
fn new(dir: String) -> Self {
268+
Self {
269+
dir,
270+
best_params: None,
271+
best_costs: None,
272+
}
273+
}
274+
}
275+
276+
impl Observe<EgorState<f64>> for OptimizationObserver {
277+
fn observe_init(
278+
&mut self,
279+
_name: &str,
280+
state: &EgorState<f64>,
281+
_kv: &KV,
282+
) -> std::result::Result<(), Error> {
283+
let bp = state.get_best_param().unwrap().to_owned();
284+
self.best_params = Some(bp.insert_axis(Axis(0)));
285+
let bc = state.get_full_best_cost().unwrap().to_owned();
286+
self.best_costs = Some(bc.insert_axis(Axis(0)));
287+
Ok(())
288+
}
289+
290+
fn observe_iter(&mut self, state: &EgorState<f64>, _kv: &KV) -> std::result::Result<(), Error> {
291+
let bp = state
292+
.get_best_param()
293+
.unwrap()
294+
.to_owned()
295+
.insert_axis(Axis(0));
296+
self.best_params = Some(concatenate![Axis(0), self.best_params.take().unwrap(), bp]);
297+
let bc = state
298+
.get_full_best_cost()
299+
.unwrap()
300+
.to_owned()
301+
.insert_axis(Axis(0));
302+
self.best_costs = Some(concatenate![Axis(0), self.best_costs.take().unwrap(), bc]);
303+
304+
Ok(())
305+
}
306+
307+
fn observe_final(&mut self, _state: &EgorState<f64>) -> std::result::Result<(), Error> {
308+
let hist = concatenate![
309+
Axis(1),
310+
self.best_costs.take().unwrap(),
311+
self.best_params.take().unwrap(),
312+
];
313+
std::fs::create_dir_all(&self.dir)?;
314+
let filepath = std::path::Path::new(&self.dir).join(HISTORY_FILE);
315+
info!("Save history {:?} in {:?}", hist.shape(), filepath);
316+
info!("History: {}", hist);
317+
ndarray_npy::write_npy(filepath, &hist).expect("Write history");
318+
Ok(())
319+
}
320+
}
321+
236322
#[cfg(test)]
237323
mod tests {
238324
use super::*;
@@ -280,7 +366,6 @@ mod tests {
280366
let expected = array![-15.1];
281367
assert_abs_diff_eq!(expected, res.y_opt, epsilon = 0.5);
282368
let saved_doe: Array2<f64> = read_npy(&outfile).unwrap();
283-
let _ = std::fs::remove_file(&outfile);
284369
assert_abs_diff_eq!(initial_doe, saved_doe.slice(s![..3, ..1]), epsilon = 1e-6);
285370
}
286371

ego/src/solver/egor_solver.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ use std::time::Instant;
125125

126126
/// Numpy filename for initial DOE dump
127127
pub const DOE_INITIAL_FILE: &str = "egor_initial_doe.npy";
128-
/// Numpy Filename for current DOE dump
128+
/// Numpy filename for current DOE dump
129129
pub const DOE_FILE: &str = "egor_doe.npy";
130130

131131
/// Default tolerance value for constraints to be satisfied (ie cstr < tol)

0 commit comments

Comments
 (0)