Skip to content

Commit caa8b54

Browse files
authored
Merge pull request #1071 from PowerGridModel/feature/remove-calc-fn-from-job-dispatch
Clean-up main model: make job dispatch calculation agnostic
2 parents 0187d85 + 8287cb8 commit caa8b54

File tree

5 files changed

+99
-83
lines changed

5 files changed

+99
-83
lines changed

power_grid_model_c/power_grid_model/include/power_grid_model/job_adapter.hpp

Lines changed: 37 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,22 @@ template <class MainModel, class... ComponentType>
2121
class JobDispatchAdapter<MainModel, ComponentList<ComponentType...>>
2222
: public JobDispatchInterface<JobDispatchAdapter<MainModel, ComponentList<ComponentType...>>> {
2323
public:
24-
JobDispatchAdapter(std::reference_wrapper<MainModel> model) : model_{std::move(model)} {}
24+
JobDispatchAdapter(std::reference_wrapper<MainModel> model_reference,
25+
std::reference_wrapper<MainModelOptions const> options)
26+
: model_reference_{model_reference}, options_{options} {}
2527
JobDispatchAdapter(JobDispatchAdapter const& other)
26-
: model_copy_{std::make_unique<MainModel>(other.model_.get())},
27-
model_{std::ref(*model_copy_)},
28+
: model_copy_{std::make_unique<MainModel>(other.model_reference_.get())},
29+
model_reference_{std::ref(*model_copy_)},
30+
options_{std::ref(other.options_)},
2831
components_to_update_{other.components_to_update_},
2932
update_independence_{other.update_independence_},
3033
independence_flags_{other.independence_flags_},
3134
all_scenarios_sequence_{other.all_scenarios_sequence_} {}
3235
JobDispatchAdapter& operator=(JobDispatchAdapter const& other) {
3336
if (this != &other) {
34-
model_copy_ = std::make_unique<MainModel>(other.model_.get());
35-
model_ = std::ref(*model_copy_);
37+
model_copy_ = std::make_unique<MainModel>(other.model_reference_.get());
38+
model_reference_ = std::ref(*model_copy_);
39+
options_ = std::ref(other.options_);
3640
components_to_update_ = other.components_to_update_;
3741
update_independence_ = other.update_independence_;
3842
independence_flags_ = other.independence_flags_;
@@ -42,15 +46,17 @@ class JobDispatchAdapter<MainModel, ComponentList<ComponentType...>>
4246
}
4347
JobDispatchAdapter(JobDispatchAdapter&& other) noexcept
4448
: model_copy_{std::move(other.model_copy_)},
45-
model_{model_copy_ ? std::ref(*model_copy_) : std::move(other.model_)},
49+
model_reference_{model_copy_ ? std::ref(*model_copy_) : std::move(other.model_reference_)},
50+
options_{other.options_},
4651
components_to_update_{std::move(other.components_to_update_)},
4752
update_independence_{std::move(other.update_independence_)},
4853
independence_flags_{std::move(other.independence_flags_)},
4954
all_scenarios_sequence_{std::move(other.all_scenarios_sequence_)} {}
5055
JobDispatchAdapter& operator=(JobDispatchAdapter&& other) noexcept {
5156
if (this != &other) {
5257
model_copy_ = std::move(other.model_copy_);
53-
model_ = model_copy_ ? std::ref(*model_copy_) : std::move(other.model_);
58+
model_reference_ = model_copy_ ? std::ref(*model_copy_) : std::move(other.model_reference_);
59+
options_ = other.options_;
5460
components_to_update_ = std::move(other.components_to_update_);
5561
update_independence_ = std::move(other.update_independence_);
5662
independence_flags_ = std::move(other.independence_flags_);
@@ -66,10 +72,9 @@ class JobDispatchAdapter<MainModel, ComponentList<ComponentType...>>
6672
// to call derived-class implementation details as part of the CRTP pattern.
6773
friend class JobDispatchInterface<JobDispatchAdapter>;
6874

69-
static constexpr Idx ignore_output{-1};
70-
7175
std::unique_ptr<MainModel> model_copy_;
72-
std::reference_wrapper<MainModel> model_;
76+
std::reference_wrapper<MainModel> model_reference_;
77+
std::reference_wrapper<MainModelOptions const> options_;
7378

7479
main_core::utils::ComponentFlags<ComponentType...> components_to_update_{};
7580
main_core::update::independence::UpdateIndependence<ComponentType...> update_independence_{};
@@ -80,26 +85,22 @@ class JobDispatchAdapter<MainModel, ComponentList<ComponentType...>>
8085

8186
std::mutex calculation_info_mutex_;
8287

83-
// TODO(figueroa1395): Keep calculation_fn at the adapter level only
84-
template <typename Calculate>
85-
requires std::invocable<std::remove_cvref_t<Calculate>, MainModel&, MutableDataset const&, bool>
86-
void calculate_impl(Calculate&& calculation_fn, MutableDataset const& result_data, Idx scenario_idx) const {
87-
std::forward<Calculate>(calculation_fn)(model_.get(), result_data.get_individual_scenario(scenario_idx), false);
88+
void calculate_impl(MutableDataset const& result_data, Idx scenario_idx) const {
89+
MainModel::calculator(options_.get(), model_reference_.get(), result_data.get_individual_scenario(scenario_idx),
90+
false);
8891
}
8992

90-
template <typename Calculate>
91-
requires std::invocable<std::remove_cvref_t<Calculate>, MainModel&, MutableDataset const&, Idx>
92-
void cache_calculate_impl(Calculate&& calculation_fn) const {
93+
void cache_calculate_impl() const {
9394
// calculate once to cache topology, ignore results, all math solvers are initialized
9495
try {
95-
std::forward<Calculate>(calculation_fn)(model_.get(),
96-
{
97-
false,
98-
1,
99-
"sym_output",
100-
model_.get().meta_data(),
101-
},
102-
true);
96+
MainModel::calculator(options_.get(), model_reference_.get(),
97+
{
98+
false,
99+
1,
100+
"sym_output",
101+
model_reference_.get().meta_data(),
102+
},
103+
true);
103104
} catch (SparseMatrixError const&) { // NOLINT(bugprone-empty-catch) // NOSONAR
104105
// missing entries are provided in the update data
105106
} catch (NotObservableError const&) { // NOLINT(bugprone-empty-catch) // NOSONAR
@@ -110,33 +111,35 @@ class JobDispatchAdapter<MainModel, ComponentList<ComponentType...>>
110111
void prepare_job_dispatch_impl(ConstDataset const& update_data) {
111112
// cache component update order where possible.
112113
// the order for a cacheable (independent) component by definition is the same across all scenarios
113-
components_to_update_ = model_.get().get_components_to_update(update_data);
114+
components_to_update_ = model_reference_.get().get_components_to_update(update_data);
114115
update_independence_ = main_core::update::independence::check_update_independence<ComponentType...>(
115-
model_.get().state(), update_data);
116+
model_reference_.get().state(), update_data);
116117
std::ranges::transform(update_independence_, independence_flags_.begin(),
117118
[](auto const& comp) { return comp.is_independent(); });
118119
all_scenarios_sequence_ = std::make_shared<main_core::utils::SequenceIdx<ComponentType...>>(
119120
main_core::update::get_all_sequence_idx_map<ComponentType...>(
120-
model_.get().state(), update_data, 0, components_to_update_, update_independence_, false));
121+
model_reference_.get().state(), update_data, 0, components_to_update_, update_independence_, false));
121122
}
122123

123124
void setup_impl(ConstDataset const& update_data, Idx scenario_idx) {
124125
current_scenario_sequence_cache_ = main_core::update::get_all_sequence_idx_map<ComponentType...>(
125-
model_.get().state(), update_data, scenario_idx, components_to_update_, update_independence_, true);
126+
model_reference_.get().state(), update_data, scenario_idx, components_to_update_, update_independence_,
127+
true);
126128
auto const current_scenario_sequence = get_current_scenario_sequence_view_();
127-
model_.get().template update_components<cached_update_t>(update_data, scenario_idx, current_scenario_sequence);
129+
model_reference_.get().template update_components<cached_update_t>(update_data, scenario_idx,
130+
current_scenario_sequence);
128131
}
129132

130133
void winddown_impl() {
131-
model_.get().restore_components(get_current_scenario_sequence_view_());
134+
model_reference_.get().restore_components(get_current_scenario_sequence_view_());
132135
std::ranges::for_each(current_scenario_sequence_cache_, [](auto& comp_seq_idx) { comp_seq_idx.clear(); });
133136
}
134137

135-
CalculationInfo get_calculation_info_impl() const { return model_.get().calculation_info(); }
138+
CalculationInfo get_calculation_info_impl() const { return model_reference_.get().calculation_info(); }
136139

137140
void thread_safe_add_calculation_info_impl(CalculationInfo const& info) {
138141
std::lock_guard const lock{calculation_info_mutex_};
139-
model_.get().merge_calculation_info(info);
142+
model_reference_.get().merge_calculation_info(info);
140143
}
141144

142145
auto get_current_scenario_sequence_view_() const {

power_grid_model_c/power_grid_model/include/power_grid_model/job_dispatch.hpp

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,22 @@
88
#include "job_interface.hpp"
99

1010
#include "main_core/calculation_info.hpp"
11-
#include "main_core/update.hpp"
1211

1312
#include <thread>
1413

1514
namespace power_grid_model {
1615

1716
class JobDispatch {
1817
public:
19-
static constexpr Idx ignore_output{-1};
2018
static constexpr Idx sequential{-1};
2119

22-
// TODO(figueroa1395): remove calculation_fn dependency
23-
// TODO(figueroa1395): add concept to Adapter template parameter
2420
// TODO(figueroa1395): add generic template parameters for update_data and result_data
25-
template <typename Adapter, typename Calculate>
26-
static BatchParameter batch_calculation(Adapter& adapter, Calculate&& calculation_fn,
27-
MutableDataset const& result_data, ConstDataset const& update_data,
28-
Idx threading = sequential) {
21+
template <typename Adapter, typename ResultDataset, typename UpdateDataset>
22+
requires std::is_base_of_v<JobDispatchInterface<Adapter>, Adapter>
23+
static BatchParameter batch_calculation(Adapter& adapter, ResultDataset const& result_data,
24+
UpdateDataset const& update_data, Idx threading = sequential) {
2925
if (update_data.empty()) {
30-
adapter.calculate(std::forward<Calculate>(calculation_fn), result_data);
26+
adapter.calculate(result_data);
3127
return BatchParameter{};
3228
}
3329

@@ -41,14 +37,13 @@ class JobDispatch {
4137
}
4238

4339
// calculate once to cache, ignore results
44-
adapter.cache_calculate(std::forward<Calculate>(calculation_fn)); // TODO(figueroa1395): Time this step
40+
adapter.cache_calculate();
4541

4642
// error messages
4743
std::vector<std::string> exceptions(n_scenarios, "");
4844

4945
adapter.prepare_job_dispatch(update_data);
50-
auto single_job =
51-
single_thread_job(adapter, std::forward<Calculate>(calculation_fn), result_data, update_data, exceptions);
46+
auto single_job = single_thread_job(adapter, result_data, update_data, exceptions);
5247

5348
job_dispatch(single_job, n_scenarios, threading);
5449

@@ -58,11 +53,10 @@ class JobDispatch {
5853
}
5954

6055
private:
61-
template <typename Adapter, typename Calculate>
62-
static auto single_thread_job(Adapter& base_adapter, Calculate&& calculation_fn, MutableDataset const& result_data,
63-
ConstDataset const& update_data, std::vector<std::string>& exceptions) {
64-
return [&base_adapter, &exceptions, calculation_fn_ = std::forward<Calculate>(calculation_fn), &result_data,
65-
&update_data](Idx start, Idx stride, Idx n_scenarios) {
56+
template <typename Adapter, typename ResultDataset, typename UpdateDataset>
57+
static auto single_thread_job(Adapter& base_adapter, ResultDataset const& result_data,
58+
UpdateDataset const& update_data, std::vector<std::string>& exceptions) {
59+
return [&base_adapter, &exceptions, &result_data, &update_data](Idx start, Idx stride, Idx n_scenarios) {
6660
assert(n_scenarios <= narrow_cast<Idx>(exceptions.size()));
6761

6862
CalculationInfo thread_info;
@@ -92,8 +86,8 @@ class JobDispatch {
9286
adapter = copy_adapter_functor();
9387
};
9488

95-
auto run = [&adapter, &calculation_fn_, &result_data, &thread_info](Idx scenario_idx) {
96-
adapter.calculate(calculation_fn_, result_data, scenario_idx);
89+
auto run = [&adapter, &result_data, &thread_info](Idx scenario_idx) {
90+
adapter.calculate(result_data, scenario_idx);
9791
main_core::merge_into(thread_info, adapter.get_calculation_info());
9892
};
9993

power_grid_model_c/power_grid_model/include/power_grid_model/job_interface.hpp

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,41 +17,63 @@
1717
namespace power_grid_model {
1818
template <typename Adapter> class JobDispatchInterface {
1919
public:
20-
template <typename Calculate, typename ResultDataset>
21-
requires requires(Adapter& adapter, Calculate&& calculation_fn, ResultDataset const& result_data,
22-
Idx scenario_idx) {
23-
{
24-
adapter.calculate_impl(std::forward<Calculate>(calculation_fn), result_data, scenario_idx)
25-
} -> std::same_as<void>;
20+
// the multiple NOSONARs are used to avoid the complaints about the unnamed concepts
21+
template <typename ResultDataset>
22+
void calculate(ResultDataset const& result_data, Idx pos = 0)
23+
requires requires(Adapter& adapter) { // NOSONAR
24+
{ adapter.calculate_impl(result_data, pos) } -> std::same_as<void>;
2625
}
27-
void calculate(Calculate&& calculation_fn, ResultDataset const& result_data, Idx scenario_idx = 0) {
28-
return static_cast<Adapter*>(this)->calculate_impl(std::forward<Calculate>(calculation_fn), result_data,
29-
scenario_idx);
26+
{
27+
return static_cast<Adapter*>(this)->calculate_impl(result_data, pos);
3028
}
3129

32-
template <typename Calculate>
33-
requires requires(Adapter& adapter, Calculate&& calculation_fn) {
34-
{ adapter.cache_calculate_impl(std::forward<Calculate>(calculation_fn)) } -> std::same_as<void>;
30+
void cache_calculate()
31+
requires requires(Adapter& adapter) { // NOSONAR
32+
{ adapter.cache_calculate_impl() } -> std::same_as<void>;
3533
}
36-
void cache_calculate(Calculate&& calculation_fn) {
37-
return static_cast<Adapter*>(this)->cache_calculate_impl(std::forward<Calculate>(calculation_fn));
34+
{
35+
return static_cast<Adapter*>(this)->cache_calculate_impl();
3836
}
3937

40-
template <typename UpdateDataset> void prepare_job_dispatch(UpdateDataset const& update_data) {
38+
template <typename UpdateDataset>
39+
void prepare_job_dispatch(UpdateDataset const& update_data)
40+
requires requires(Adapter& adapter) { // NOSONAR
41+
{ adapter.prepare_job_dispatch_impl(update_data) } -> std::same_as<void>;
42+
}
43+
{
4144
return static_cast<Adapter*>(this)->prepare_job_dispatch_impl(update_data);
4245
}
4346

44-
template <typename UpdateDataset> void setup(UpdateDataset const& update_data, Idx scenario_idx) {
47+
template <typename UpdateDataset>
48+
void setup(UpdateDataset const& update_data, Idx scenario_idx)
49+
requires requires(Adapter& adapter) { // NOSONAR
50+
{ adapter.setup_impl(update_data, scenario_idx) } -> std::same_as<void>;
51+
}
52+
{
4553
return static_cast<Adapter*>(this)->setup_impl(update_data, scenario_idx);
4654
}
4755

48-
void winddown() { return static_cast<Adapter*>(this)->winddown_impl(); }
56+
void winddown()
57+
requires requires(Adapter& adapter) { // NOSONAR
58+
{ adapter.winddown_impl() } -> std::same_as<void>;
59+
}
60+
{
61+
return static_cast<Adapter*>(this)->winddown_impl();
62+
}
4963

50-
CalculationInfo get_calculation_info() const {
64+
CalculationInfo get_calculation_info() const
65+
requires requires(Adapter& adapter) { // NOSONAR
66+
{ adapter.get_calculation_info_impl() } -> std::same_as<CalculationInfo>;
67+
}
68+
{
5169
return static_cast<const Adapter*>(this)->get_calculation_info_impl();
5270
}
5371

54-
void thread_safe_add_calculation_info(CalculationInfo const& info) {
72+
void thread_safe_add_calculation_info(CalculationInfo const& info)
73+
requires requires(Adapter& adapter) { // NOSONAR
74+
{ adapter.thread_safe_add_calculation_info_impl(info) } -> std::same_as<void>;
75+
}
76+
{
5577
static_cast<Adapter*>(this)->thread_safe_add_calculation_info_impl(info);
5678
}
5779

power_grid_model_c/power_grid_model/include/power_grid_model/main_model.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ class MainModel {
6969
Run the calculation function in batch on the provided update data.
7070
7171
The calculation function should be able to run standalone.
72-
It should output to the provided result_data if the trailing argument is not ignore_output.
7372
7473
threading
7574
< 0 sequential
@@ -79,9 +78,8 @@ class MainModel {
7978
*/
8079
BatchParameter calculate(Options const& options, MutableDataset const& result_data,
8180
ConstDataset const& update_data) {
82-
JobDispatchAdapter<Impl, AllComponents> adapter{std::ref(impl())};
83-
return JobDispatch::batch_calculation(adapter, Impl::calculator(options), result_data, update_data,
84-
options.threading);
81+
JobDispatchAdapter<Impl, AllComponents> adapter{std::ref(impl()), std::ref(options)};
82+
return JobDispatch::batch_calculation(adapter, result_data, update_data, options.threading);
8583
}
8684

8785
CalculationInfo calculation_info() const { return impl().calculation_info(); }

power_grid_model_c/power_grid_model/include/power_grid_model/main_model_impl.hpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -544,14 +544,13 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
544544
}
545545

546546
public:
547-
static auto calculator(Options const& options) {
548-
return [options](MainModelImpl& model, MutableDataset const& target_data, bool cache_run) {
549-
auto sub_opt = options; // copy
550-
sub_opt.err_tol = cache_run ? std::numeric_limits<double>::max() : options.err_tol;
551-
sub_opt.max_iter = cache_run ? 1 : options.max_iter;
547+
static auto calculator(Options const& options, MainModelImpl& model, MutableDataset const& target_data,
548+
bool cache_run) {
549+
auto sub_opt = options; // copy
550+
sub_opt.err_tol = cache_run ? std::numeric_limits<double>::max() : options.err_tol;
551+
sub_opt.max_iter = cache_run ? 1 : options.max_iter;
552552

553-
model.calculate(sub_opt, target_data);
554-
};
553+
model.calculate(sub_opt, target_data);
555554
}
556555

557556
CalculationInfo calculation_info() const { return calculation_info_; }

0 commit comments

Comments
 (0)