@@ -102,15 +102,22 @@ use crate::errors::Result;
102
102
use crate :: gpmix:: mixint:: * ;
103
103
use crate :: types:: * ;
104
104
use crate :: EgorConfig ;
105
+ use crate :: EgorState ;
105
106
use crate :: { to_xtypes, EgorSolver } ;
106
107
108
+ use argmin:: core:: observers:: ObserverMode ;
107
109
use egobox_moe:: GpMixtureParams ;
108
110
use log:: info;
109
- use ndarray:: { concatenate, ArrayBase , Axis , Data , Ix2 } ;
111
+ use ndarray:: { concatenate, Array2 , ArrayBase , Axis , Data , Ix2 } ;
110
112
use ndarray_rand:: rand:: SeedableRng ;
111
113
use rand_xoshiro:: Xoshiro256Plus ;
112
114
113
- use argmin:: core:: { Executor , State } ;
115
+ use argmin:: core:: { observers:: Observe , Error , Executor , State , KV } ;
116
+
117
+ /// Json filename for configuration
118
+ pub const CONFIG_FILE : & str = "egor_config.json" ;
119
+ /// Numpy filename for optimization history
120
+ pub const HISTORY_FILE : & str = "egor_history.npy" ;
114
121
115
122
/// EGO optimizer builder allowing to specify function to be minimized
116
123
/// subject to constraints intended to be negative.
@@ -194,13 +201,25 @@ impl<O: GroupFunc, SB: SurrogateBuilder> Egor<O, SB> {
194
201
pub fn run ( & self ) -> Result < OptimResult < f64 > > {
195
202
let xtypes = self . solver . config . xtypes . clone ( ) ;
196
203
info ! ( "{:?}" , self . solver. config) ;
204
+ if let Some ( outdir) = self . solver . config . outdir . as_ref ( ) {
205
+ std:: fs:: create_dir_all ( outdir) ?;
206
+ let filepath = std:: path:: Path :: new ( outdir) . join ( CONFIG_FILE ) ;
207
+ let json = serde_json:: to_string ( & self . solver . config ) . unwrap ( ) ;
208
+ std:: fs:: write ( filepath, json) . expect ( "Unable to write file" ) ;
209
+ }
197
210
198
- let result = Executor :: new ( self . fobj . clone ( ) , self . solver . clone ( ) ) . run ( ) ?;
211
+ let exec = Executor :: new ( self . fobj . clone ( ) , self . solver . clone ( ) ) ;
212
+ let result = if let Some ( outdir) = self . solver . config . outdir . as_ref ( ) {
213
+ let hist = OptimizationObserver :: new ( outdir. clone ( ) ) ;
214
+ exec. add_observer ( hist, ObserverMode :: Always ) . run ( ) ?
215
+ } else {
216
+ exec. run ( ) ?
217
+ } ;
199
218
info ! ( "{}" , result) ;
200
219
let ( x_data, y_data) = result. state ( ) . clone ( ) . take_data ( ) . unwrap ( ) ;
201
220
202
221
let res = if !self . solver . config . discrete ( ) {
203
- info ! ( "History : \n {}" , concatenate![ Axis ( 1 ) , x_data, y_data] ) ;
222
+ info ! ( "Data : \n {}" , concatenate![ Axis ( 1 ) , x_data, y_data] ) ;
204
223
OptimResult {
205
224
x_opt : result. state . get_best_param ( ) . unwrap ( ) . to_owned ( ) ,
206
225
y_opt : result. state . get_full_best_cost ( ) . unwrap ( ) . to_owned ( ) ,
@@ -210,7 +229,7 @@ impl<O: GroupFunc, SB: SurrogateBuilder> Egor<O, SB> {
210
229
}
211
230
} else {
212
231
let x_data = to_discrete_space ( & xtypes, & x_data. view ( ) ) ;
213
- info ! ( "History : \n {}" , concatenate![ Axis ( 1 ) , x_data, y_data] ) ;
232
+ info ! ( "Data : \n {}" , concatenate![ Axis ( 1 ) , x_data, y_data] ) ;
214
233
215
234
let x_opt = result
216
235
. state
@@ -233,6 +252,73 @@ impl<O: GroupFunc, SB: SurrogateBuilder> Egor<O, SB> {
233
252
}
234
253
}
235
254
255
+ // The optimization observer collects best costs ans params
256
+ // during the optimization execution allowing to get optimization history
257
+ // saved as a numpy array for further analysis
258
+ // Note: the observer is activated only when outdir is specified
259
+ #[ derive( Default ) ]
260
+ struct OptimizationObserver {
261
+ pub dir : String ,
262
+ pub best_params : Option < Array2 < f64 > > ,
263
+ pub best_costs : Option < Array2 < f64 > > ,
264
+ }
265
+
266
+ impl OptimizationObserver {
267
+ fn new ( dir : String ) -> Self {
268
+ Self {
269
+ dir,
270
+ best_params : None ,
271
+ best_costs : None ,
272
+ }
273
+ }
274
+ }
275
+
276
+ impl Observe < EgorState < f64 > > for OptimizationObserver {
277
+ fn observe_init (
278
+ & mut self ,
279
+ _name : & str ,
280
+ state : & EgorState < f64 > ,
281
+ _kv : & KV ,
282
+ ) -> std:: result:: Result < ( ) , Error > {
283
+ let bp = state. get_best_param ( ) . unwrap ( ) . to_owned ( ) ;
284
+ self . best_params = Some ( bp. insert_axis ( Axis ( 0 ) ) ) ;
285
+ let bc = state. get_full_best_cost ( ) . unwrap ( ) . to_owned ( ) ;
286
+ self . best_costs = Some ( bc. insert_axis ( Axis ( 0 ) ) ) ;
287
+ Ok ( ( ) )
288
+ }
289
+
290
+ fn observe_iter ( & mut self , state : & EgorState < f64 > , _kv : & KV ) -> std:: result:: Result < ( ) , Error > {
291
+ let bp = state
292
+ . get_best_param ( )
293
+ . unwrap ( )
294
+ . to_owned ( )
295
+ . insert_axis ( Axis ( 0 ) ) ;
296
+ self . best_params = Some ( concatenate ! [ Axis ( 0 ) , self . best_params. take( ) . unwrap( ) , bp] ) ;
297
+ let bc = state
298
+ . get_full_best_cost ( )
299
+ . unwrap ( )
300
+ . to_owned ( )
301
+ . insert_axis ( Axis ( 0 ) ) ;
302
+ self . best_costs = Some ( concatenate ! [ Axis ( 0 ) , self . best_costs. take( ) . unwrap( ) , bc] ) ;
303
+
304
+ Ok ( ( ) )
305
+ }
306
+
307
+ fn observe_final ( & mut self , _state : & EgorState < f64 > ) -> std:: result:: Result < ( ) , Error > {
308
+ let hist = concatenate ! [
309
+ Axis ( 1 ) ,
310
+ self . best_costs. take( ) . unwrap( ) ,
311
+ self . best_params. take( ) . unwrap( ) ,
312
+ ] ;
313
+ std:: fs:: create_dir_all ( & self . dir ) ?;
314
+ let filepath = std:: path:: Path :: new ( & self . dir ) . join ( HISTORY_FILE ) ;
315
+ info ! ( "Save history {:?} in {:?}" , hist. shape( ) , filepath) ;
316
+ info ! ( "History: {}" , hist) ;
317
+ ndarray_npy:: write_npy ( filepath, & hist) . expect ( "Write history" ) ;
318
+ Ok ( ( ) )
319
+ }
320
+ }
321
+
236
322
#[ cfg( test) ]
237
323
mod tests {
238
324
use super :: * ;
@@ -280,7 +366,6 @@ mod tests {
280
366
let expected = array ! [ -15.1 ] ;
281
367
assert_abs_diff_eq ! ( expected, res. y_opt, epsilon = 0.5 ) ;
282
368
let saved_doe: Array2 < f64 > = read_npy ( & outfile) . unwrap ( ) ;
283
- let _ = std:: fs:: remove_file ( & outfile) ;
284
369
assert_abs_diff_eq ! ( initial_doe, saved_doe. slice( s![ ..3 , ..1 ] ) , epsilon = 1e-6 ) ;
285
370
}
286
371
0 commit comments