@@ -41,9 +41,9 @@ function _convert(
41
41
) where {S, T}
42
42
43
43
if node. is_leaf
44
- featfreq = Tuple (sum (labels[node. region] .== l) for l in list)
44
+ classfreq = Tuple (sum (labels[node. region] .== l) for l in list)
45
45
return Leaf {T, length(list)} (
46
- Tuple (list), argmax (featfreq ), featfreq , length (node. region))
46
+ Tuple (list), argmax (classfreq ), classfreq , length (node. region))
47
47
else
48
48
left = _convert (node. l, list, labels)
49
49
right = _convert (node. r, list, labels)
@@ -117,6 +117,7 @@ function build_stump(
117
117
labels :: AbstractVector{T} ,
118
118
features :: AbstractMatrix{S} ,
119
119
weights = nothing ;
120
+ n_classes :: Int = length (unique (labels)),
120
121
rng = Random. GLOBAL_RNG,
121
122
impurity_importance :: Bool = true ) where {S, T}
122
123
@@ -133,7 +134,7 @@ function build_stump(
133
134
min_purity_increase = 0.0 ;
134
135
rng = rng)
135
136
136
- return _build_tree (t, labels, size (features, 2 ), size (features, 1 ), impurity_importance)
137
+ return _build_tree (t, labels, n_classes, size (features, 2 ), size (features, 1 ), impurity_importance)
137
138
end
138
139
139
140
function build_tree (
@@ -144,6 +145,7 @@ function build_tree(
144
145
min_samples_leaf = 1 ,
145
146
min_samples_split = 2 ,
146
147
min_purity_increase = 0.0 ;
148
+ n_classes :: Int = length (unique (labels)),
147
149
loss = util. entropy :: Function ,
148
150
rng = Random. GLOBAL_RNG,
149
151
impurity_importance :: Bool = true ) where {S, T}
@@ -168,18 +170,18 @@ function build_tree(
168
170
min_purity_increase = Float64 (min_purity_increase),
169
171
rng = rng)
170
172
171
- return _build_tree (t, labels, size (features, 2 ), size (features, 1 ), impurity_importance)
173
+ return _build_tree (t, labels, n_classes, size (features, 2 ), size (features, 1 ), impurity_importance)
172
174
end
173
175
174
176
function _build_tree (
175
177
tree:: treeclassifier.Tree{S, T} ,
176
178
labels:: AbstractVector{T} ,
179
+ n_classes:: Int ,
177
180
n_features,
178
181
n_samples,
179
182
impurity_importance:: Bool
180
183
) where {S, T}
181
184
node = _convert (tree. root, tree. list, labels[tree. labels])
182
- n_classes = unique (labels) |> length
183
185
if ! impurity_importance
184
186
return Root {S, T, n_classes} (node, n_features, Float64[])
185
187
else
@@ -237,15 +239,15 @@ function prune_tree(
237
239
if ! isempty (fi)
238
240
update_pruned_impurity! (tree, fi, ntt, loss)
239
241
end
240
- return Leaf {T, N} (tree. left. features , majority, combined, total)
242
+ return Leaf {T, N} (tree. left. classes , majority, combined, total)
241
243
else
242
244
return tree
243
245
end
244
246
end
245
247
function _prune_run (tree:: Root{S, T, N} , purity_thresh:: Real ) where {S, T, N}
246
248
fi = deepcopy (tree. featim) # # recalculate feature importances
247
249
node = _prune_run (tree. node, purity_thresh, fi)
248
- return Root {S, T, N} (node, fi)
250
+ return Root {S, T, N} (node, tree . n_feat, fi)
249
251
end
250
252
function _prune_run (
251
253
tree:: LeafOrNode{S, T, N} ,
@@ -273,7 +275,7 @@ function prune_tree(
273
275
end
274
276
275
277
276
- apply_tree (leaf:: Leaf , feature:: AbstractVector ) = leaf. features [leaf. majority]
278
+ apply_tree (leaf:: Leaf , feature:: AbstractVector ) = leaf. classes [leaf. majority]
277
279
apply_tree (
278
280
tree:: Root{S, T} ,
279
281
features:: AbstractVector{S}
@@ -369,10 +371,11 @@ function build_forest(
369
371
370
372
t_samples = length (labels)
371
373
n_samples = floor (Int, partial_sampling * t_samples)
374
+ n_classes = length (unique (labels))
372
375
373
376
forest = impurity_importance ?
374
- Vector {Root{S, T}} (undef, n_trees) :
375
- Vector {LeafOrNode{S, T}} (undef, n_trees)
377
+ Vector {Root{S, T, n_classes }} (undef, n_trees) :
378
+ Vector {LeafOrNode{S, T, n_classes }} (undef, n_trees)
376
379
377
380
entropy_terms = util. compute_entropy_terms (n_samples)
378
381
loss = (ns, n) -> util. entropy (ns, n, entropy_terms)
@@ -390,7 +393,8 @@ function build_forest(
390
393
max_depth,
391
394
min_samples_leaf,
392
395
min_samples_split,
393
- min_purity_increase,
396
+ min_purity_increase;
397
+ n_classes,
394
398
loss = loss,
395
399
rng = _rng,
396
400
impurity_importance = impurity_importance)
@@ -406,7 +410,8 @@ function build_forest(
406
410
max_depth,
407
411
min_samples_leaf,
408
412
min_samples_split,
409
- min_purity_increase,
413
+ min_purity_increase;
414
+ n_classes,
410
415
loss = loss,
411
416
impurity_importance = impurity_importance)
412
417
end
@@ -416,13 +421,13 @@ function build_forest(
416
421
end
417
422
418
423
function _build_forest (
419
- forest :: Vector{<: Union{Root{S, T}, LeafOrNode{S, T}}} ,
424
+ forest :: Vector{<: Union{Root{S, T, N }, LeafOrNode{S, T, N }}} ,
420
425
n_features ,
421
426
n_trees ,
422
- impurity_importance :: Bool ) where {S, T}
427
+ impurity_importance :: Bool ) where {S, T, N }
423
428
424
429
if ! impurity_importance
425
- return Ensemble {S, T} (forest, n_features, Float64[])
430
+ return Ensemble {S, T, N } (forest, n_features, Float64[])
426
431
else
427
432
fi = zeros (Float64, n_features)
428
433
for tree in forest
@@ -432,12 +437,12 @@ function _build_forest(
432
437
end
433
438
end
434
439
435
- forest_new = Vector {LeafOrNode{S, T}} (undef, n_trees)
440
+ forest_new = Vector {LeafOrNode{S, T, N }} (undef, n_trees)
436
441
Threads. @threads for i in 1 : n_trees
437
442
forest_new[i] = forest[i]. node
438
443
end
439
444
440
- return Ensemble {S, T} (forest_new, n_features, fi ./ n_trees)
445
+ return Ensemble {S, T, N } (forest_new, n_features, fi ./ n_trees)
441
446
end
442
447
end
443
448
@@ -514,11 +519,13 @@ function build_adaboost_stumps(
514
519
stumps = Node{S, T}[]
515
520
coeffs = Float64[]
516
521
n_features = size (features, 2 )
522
+ n_classes = length (unique (labels))
517
523
for i in 1 : n_iterations
518
524
new_stump = build_stump (
519
525
labels,
520
526
features,
521
527
weights;
528
+ n_classes,
522
529
rng= mk_rng (rng),
523
530
impurity_importance= false
524
531
)
@@ -538,7 +545,7 @@ function build_adaboost_stumps(
538
545
break
539
546
end
540
547
end
541
- return (Ensemble {S, T} (stumps, n_features, Float64[]), coeffs)
548
+ return (Ensemble {S, T, n_classes } (stumps, n_features, Float64[]), coeffs)
542
549
end
543
550
544
551
apply_adaboost_stumps (
0 commit comments