Skip to content
Open
4 changes: 3 additions & 1 deletion econml/tree/_criterion.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ cdef class Criterion:
cdef void node_value(self, double* dest) nogil
cdef void node_jacobian(self, double* dest) nogil
cdef void node_precond(self, double* dest) nogil
cdef double impurity_improvement(self, double impurity) nogil
cdef double impurity_improvement(self, double impurity_parent,
double impurity_left,
double impurity_right) nogil
cdef double proxy_impurity_improvement(self) nogil
cdef double min_eig_left(self) nogil
cdef double min_eig_right(self) nogil
Expand Down
24 changes: 14 additions & 10 deletions econml/tree/_criterion.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ cdef class Criterion:
return (- self.weighted_n_right * impurity_right
- self.weighted_n_left * impurity_left)

cdef double impurity_improvement(self, double impurity) nogil:
cdef double impurity_improvement(self, double impurity_parent,
double impurity_left,
double impurity_right) nogil:
"""Compute the improvement in impurity
This method computes the improvement in impurity when a split occurs.
The weighted impurity improvement equation is the following:
Expand All @@ -218,22 +220,24 @@ cdef class Criterion:
where N is the total number of samples, N_t is the number of samples
at the current node, N_t_L is the number of samples in the left child,
and N_t_R is the number of samples in the right child,

Parameters
----------
impurity : double
The initial impurity of the node before the split
impurity_parent : float64_t
The initial impurity of the parent node before the split

impurity_left : float64_t
The impurity of the left child

impurity_right : float64_t
The impurity of the right child

Return
------
double : improvement in impurity after the split occurs
"""

cdef double impurity_left
cdef double impurity_right

self.children_impurity(&impurity_left, &impurity_right)

return ((self.weighted_n_node_samples / self.weighted_n_samples) *
(impurity - (self.weighted_n_right /
(impurity_parent - (self.weighted_n_right /
self.weighted_n_node_samples * impurity_right)
- (self.weighted_n_left /
self.weighted_n_node_samples * impurity_left)))
Expand Down
4 changes: 3 additions & 1 deletion econml/tree/_splitter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,6 @@ cdef class BestSplitter(Splitter):
# passed here by the TreeBuilder. The TreeBuilder uses the proxy_node_impurity() to calculate
# this baseline if self.is_children_impurity_proxy(), else uses the call to children_impurity()
# on the parent node, when that node was split.
best.improvement = self.criterion.impurity_improvement(impurity)
# if we need children impurities by the builder, then we populate these entries
# otherwise, we leave them blank to avoid the extra computation.
if not self.is_children_impurity_proxy():
Expand All @@ -630,6 +629,9 @@ cdef class BestSplitter(Splitter):
else:
best.impurity_left_val = best.impurity_left
best.impurity_right_val = best.impurity_right

best.improvement = self.criterion.impurity_improvement(impurity,
best.impurity_left, best.impurity_right)

# Respect invariant for constant features: the original order of
# element in features[:n_known_constants] must be preserved for sibling
Expand Down