diff --git a/econml/tree/_criterion.pxd b/econml/tree/_criterion.pxd index 2eb6b8ba3..c0125d788 100644 --- a/econml/tree/_criterion.pxd +++ b/econml/tree/_criterion.pxd @@ -78,7 +78,9 @@ cdef class Criterion: cdef void node_value(self, double* dest) nogil cdef void node_jacobian(self, double* dest) nogil cdef void node_precond(self, double* dest) nogil - cdef double impurity_improvement(self, double impurity) nogil + cdef double impurity_improvement(self, double impurity_parent, + double impurity_left, + double impurity_right) nogil cdef double proxy_impurity_improvement(self) nogil cdef double min_eig_left(self) nogil cdef double min_eig_right(self) nogil diff --git a/econml/tree/_criterion.pyx b/econml/tree/_criterion.pyx index d0b12a92f..a2f9acdb3 100644 --- a/econml/tree/_criterion.pyx +++ b/econml/tree/_criterion.pyx @@ -209,7 +209,9 @@ cdef class Criterion: return (- self.weighted_n_right * impurity_right - self.weighted_n_left * impurity_left) - cdef double impurity_improvement(self, double impurity) nogil: + cdef double impurity_improvement(self, double impurity_parent, + double impurity_left, + double impurity_right) nogil: """Compute the improvement in impurity This method computes the improvement in impurity when a split occurs. The weighted impurity improvement equation is the following: @@ -218,22 +220,24 @@ cdef class Criterion: where N is the total number of samples, N_t is the number of samples at the current node, N_t_L is the number of samples in the left child, and N_t_R is the number of samples in the right child, + Parameters ---------- - impurity : double - The initial impurity of the node before the split + impurity_parent : float64_t + The initial impurity of the parent node before the split + + impurity_left : float64_t + The impurity of the left child + + impurity_right : float64_t + The impurity of the right child + Return ------ double : improvement in impurity after the split occurs """ - - cdef double impurity_left - cdef double impurity_right - - self.children_impurity(&impurity_left, &impurity_right) - return ((self.weighted_n_node_samples / self.weighted_n_samples) * - (impurity - (self.weighted_n_right / + (impurity_parent - (self.weighted_n_right / self.weighted_n_node_samples * impurity_right) - (self.weighted_n_left / self.weighted_n_node_samples * impurity_left))) diff --git a/econml/tree/_splitter.pyx b/econml/tree/_splitter.pyx index 557dc6e59..d022c2038 100644 --- a/econml/tree/_splitter.pyx +++ b/econml/tree/_splitter.pyx @@ -619,7 +619,6 @@ cdef class BestSplitter(Splitter): # passed here by the TreeBuilder. The TreeBuilder uses the proxy_node_impurity() to calculate # this baseline if self.is_children_impurity_proxy(), else uses the call to children_impurity() # on the parent node, when that node was split. - best.improvement = self.criterion.impurity_improvement(impurity) # if we need children impurities by the builder, then we populate these entries # otherwise, we leave them blank to avoid the extra computation. if not self.is_children_impurity_proxy(): @@ -630,6 +629,9 @@ cdef class BestSplitter(Splitter): else: best.impurity_left_val = best.impurity_left best.impurity_right_val = best.impurity_right + + best.improvement = self.criterion.impurity_improvement(impurity, + best.impurity_left, best.impurity_right) # Respect invariant for constant features: the original order of # element in features[:n_known_constants] must be preserved for sibling