
Commit 700afd2

Fix more test results
1 parent 4f94a27 commit 700afd2

5 files changed: +54 −38 lines changed

src/DecisionTree.jl

Lines changed: 10 additions & 5 deletions
@@ -27,7 +27,7 @@ export InfoNode, InfoLeaf, wrap
 ########## Types ##########

 struct Leaf{T, N}
-    features :: NTuple{N, T}
+    classes  :: NTuple{N, T}
     majority :: Int
     values   :: NTuple{N, Int}
     total    :: Int
@@ -54,15 +54,20 @@ struct Ensemble{S, T, N}
     featim :: Vector{Float64}
 end

-Leaf(features::NTuple{T, N}) where {T, N} =
+Leaf(features::NTuple{N, T}) where {T, N} =
     Leaf(features, 0, Tuple(zeros(T, N)), 0)
+Leaf(features::NTuple{N, T}, frequencies::NTuple{N, Int}) where {T, N} =
+    Leaf(features, argmax(frequencies), frequencies, sum(frequencies))
+Leaf(features::Union{<:AbstractVector, <:Tuple},
+     frequencies::Union{<:AbstractVector{Int}, <:Tuple}) =
+    Leaf(Tuple(features), Tuple(frequencies))

 is_leaf(l::Leaf) = true
 is_leaf(n::Node) = false

 _zero(::Type{String}) = ""
 _zero(x::Any) = zero(x)
-convert(::Type{Node{S, T, N}}, lf::Leaf{T, N}) where {S, T, N} = Node(0, _zero(S), lf, Leaf(lf.features))
+convert(::Type{Node{S, T, N}}, lf::Leaf{T, N}) where {S, T, N} = Node(0, _zero(S), lf, Leaf(lf.classes))
 convert(::Type{Root{S, T, N}}, node::LeafOrNode{S, T, N}) where {S, T, N} = Root{S, T, N}(node, 0, Float64[])
 convert(::Type{LeafOrNode{S, T, N}}, tree::Root{S, T, N}) where {S, T, N} = tree.node
 promote_rule(::Type{Node{S, T, N}}, ::Type{Leaf{T, N}}) where {S, T, N} = Node{S, T, N}
@@ -101,7 +106,7 @@ depth(tree::Node) = 1 + max(depth(tree.left), depth(tree.right))
 depth(tree::Root) = depth(tree.node)

 function print_tree(io::IO, leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
-    println(io, leaf.features[leaf.majority], " : ",
+    println(io, leaf.classes[leaf.majority], " : ",
             leaf.values[leaf.majority], '/', leaf.total)
 end
 function print_tree(leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
@@ -165,7 +170,7 @@ end

 function show(io::IO, leaf::Leaf)
     println(io, "Decision Leaf")
-    println(io, "Majority: ", leaf.features[leaf.majority])
+    println(io, "Majority: ", leaf.classes[leaf.majority])
     print(io, "Samples: ", leaf.total)
 end
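
For context, a minimal sketch of how the added two-argument Leaf constructor behaves; the class labels and counts below are invented for illustration:

    # Hypothetical example: three classes with observed counts (2, 1, 0).
    leaf = Leaf((1, 2, 3), (2, 1, 0))
    leaf.majority                 # 1, the index of the most frequent class
    leaf.classes[leaf.majority]   # 1, the predicted class
    leaf.total                    # 3, equal to sum(leaf.values)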

src/classification/main.jl

Lines changed: 25 additions & 18 deletions
@@ -41,9 +41,9 @@ function _convert(
 ) where {S, T}

     if node.is_leaf
-        featfreq = Tuple(sum(labels[node.region] .== l) for l in list)
+        classfreq = Tuple(sum(labels[node.region] .== l) for l in list)
         return Leaf{T, length(list)}(
-            Tuple(list), argmax(featfreq), featfreq, length(node.region))
+            Tuple(list), argmax(classfreq), classfreq, length(node.region))
     else
         left = _convert(node.l, list, labels)
         right = _convert(node.r, list, labels)
@@ -117,6 +117,7 @@ function build_stump(
         labels      :: AbstractVector{T},
         features    :: AbstractMatrix{S},
         weights      = nothing;
+        n_classes   :: Int = length(unique(labels)),
         rng          = Random.GLOBAL_RNG,
         impurity_importance :: Bool = true) where {S, T}

@@ -133,7 +134,7 @@
         min_purity_increase = 0.0;
         rng                 = rng)

-    return _build_tree(t, labels, size(features, 2), size(features, 1), impurity_importance)
+    return _build_tree(t, labels, n_classes, size(features, 2), size(features, 1), impurity_importance)
 end

 function build_tree(
@@ -144,6 +145,7 @@ function build_tree(
         min_samples_leaf     = 1,
         min_samples_split    = 2,
         min_purity_increase  = 0.0;
+        n_classes           :: Int = length(unique(labels)),
         loss                 = util.entropy :: Function,
         rng                  = Random.GLOBAL_RNG,
         impurity_importance :: Bool = true) where {S, T}
@@ -168,18 +170,18 @@
         min_purity_increase = Float64(min_purity_increase),
         rng                 = rng)

-    return _build_tree(t, labels, size(features, 2), size(features, 1), impurity_importance)
+    return _build_tree(t, labels, n_classes, size(features, 2), size(features, 1), impurity_importance)
 end

 function _build_tree(
     tree::treeclassifier.Tree{S, T},
     labels::AbstractVector{T},
+    n_classes::Int,
     n_features,
     n_samples,
     impurity_importance::Bool
 ) where {S, T}
     node = _convert(tree.root, tree.list, labels[tree.labels])
-    n_classes = unique(labels) |> length
     if !impurity_importance
         return Root{S, T, n_classes}(node, n_features, Float64[])
     else
@@ -237,15 +239,15 @@ function prune_tree(
         if !isempty(fi)
             update_pruned_impurity!(tree, fi, ntt, loss)
         end
-        return Leaf{T, N}(tree.left.features, majority, combined, total)
+        return Leaf{T, N}(tree.left.classes, majority, combined, total)
     else
         return tree
     end
 end
 function _prune_run(tree::Root{S, T, N}, purity_thresh::Real) where {S, T, N}
     fi = deepcopy(tree.featim) ## recalculate feature importances
     node = _prune_run(tree.node, purity_thresh, fi)
-    return Root{S, T, N}(node, fi)
+    return Root{S, T, N}(node, tree.n_feat, fi)
 end
 function _prune_run(
     tree::LeafOrNode{S, T, N},
@@ -273,7 +275,7 @@ function prune_tree(
 end


-apply_tree(leaf::Leaf, feature::AbstractVector) = leaf.features[leaf.majority]
+apply_tree(leaf::Leaf, feature::AbstractVector) = leaf.classes[leaf.majority]
 apply_tree(
     tree::Root{S, T},
     features::AbstractVector{S}
@@ -369,10 +371,11 @@ function build_forest(

     t_samples = length(labels)
     n_samples = floor(Int, partial_sampling * t_samples)
+    n_classes = length(unique(labels))

     forest = impurity_importance ?
-        Vector{Root{S, T}}(undef, n_trees) :
-        Vector{LeafOrNode{S, T}}(undef, n_trees)
+        Vector{Root{S, T, n_classes}}(undef, n_trees) :
+        Vector{LeafOrNode{S, T, n_classes}}(undef, n_trees)

     entropy_terms = util.compute_entropy_terms(n_samples)
     loss = (ns, n) -> util.entropy(ns, n, entropy_terms)
@@ -390,7 +393,8 @@
                 max_depth,
                 min_samples_leaf,
                 min_samples_split,
-                min_purity_increase,
+                min_purity_increase;
+                n_classes,
                 loss = loss,
                 rng = _rng,
                 impurity_importance = impurity_importance)
@@ -406,7 +410,8 @@
                 max_depth,
                 min_samples_leaf,
                 min_samples_split,
-                min_purity_increase,
+                min_purity_increase;
+                n_classes,
                 loss = loss,
                 impurity_importance = impurity_importance)
         end
@@ -416,13 +421,13 @@
 end

 function _build_forest(
-        forest              :: Vector{<: Union{Root{S, T}, LeafOrNode{S, T}}},
+        forest              :: Vector{<: Union{Root{S, T, N}, LeafOrNode{S, T, N}}},
         n_features          ,
         n_trees             ,
-        impurity_importance :: Bool) where {S, T}
+        impurity_importance :: Bool) where {S, T, N}

     if !impurity_importance
-        return Ensemble{S, T}(forest, n_features, Float64[])
+        return Ensemble{S, T, N}(forest, n_features, Float64[])
     else
         fi = zeros(Float64, n_features)
         for tree in forest
@@ -432,12 +437,12 @@ function _build_forest(
             end
         end

-        forest_new = Vector{LeafOrNode{S, T}}(undef, n_trees)
+        forest_new = Vector{LeafOrNode{S, T, N}}(undef, n_trees)
         Threads.@threads for i in 1:n_trees
             forest_new[i] = forest[i].node
         end

-        return Ensemble{S, T}(forest_new, n_features, fi ./ n_trees)
+        return Ensemble{S, T, N}(forest_new, n_features, fi ./ n_trees)
     end
 end

@@ -514,11 +519,13 @@ function build_adaboost_stumps(
     stumps = Node{S, T}[]
     coeffs = Float64[]
     n_features = size(features, 2)
+    n_classes = length(unique(labels))
     for i in 1:n_iterations
         new_stump = build_stump(
             labels,
             features,
             weights;
+            n_classes,
             rng=mk_rng(rng),
             impurity_importance=false
         )
@@ -538,7 +545,7 @@ function build_adaboost_stumps(
             break
         end
     end
-    return (Ensemble{S, T}(stumps, n_features, Float64[]), coeffs)
+    return (Ensemble{S, T, n_classes}(stumps, n_features, Float64[]), coeffs)
 end

 apply_adaboost_stumps(
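
For context, a minimal sketch of calling the classification build_tree with the new n_classes keyword; the toy features and labels below are invented for illustration, and all other arguments keep their defaults:

    using DecisionTree

    features = [1.0 2.0; 2.0 1.0; 3.0 4.0; 4.0 3.0; 5.0 6.0; 6.0 5.0]  # toy data
    labels   = [1, 1, 2, 2, 3, 3]

    # n_classes defaults to length(unique(labels)); passing it explicitly pins the
    # N type parameter of Leaf/Root even if a subsample does not contain every class.
    tree = build_tree(labels, features; n_classes = 3)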

src/regression/main.jl

Lines changed: 13 additions & 9 deletions
@@ -1,15 +1,15 @@
 include("tree.jl")

 function _convert(node::treeregressor.NodeMeta{S}, labels::Array{T}) where {S, T <: Float64}
+    classes = Tuple(unique(labels))
     if node.is_leaf
-        features = Tuple(unique(labels))
-        featfreq = Tuple(sum(labels[node.region] .== f) for f in features)
-        return Leaf{T, length(features)}(
-            features, argmax(featfreq), featfreq, length(node.region))
+        classfreq = Tuple(sum(labels[node.region] .== f) for f in classes)
+        return Leaf{T, length(classes)}(
+            classes, argmax(classfreq), classfreq, length(node.region))
     else
         left = _convert(node.l, labels)
         right = _convert(node.r, labels)
-        return Node{S, T}(node.feature, node.threshold, left, right)
+        return Node{S, T, length(classes)}(node.feature, node.threshold, left, right)
     end
 end

@@ -34,6 +34,7 @@ function build_tree(
     min_samples_leaf     = 5,
     min_samples_split    = 2,
     min_purity_increase  = 0.0;
+    n_classes           :: Int = length(unique(labels)),
     rng                  = Random.GLOBAL_RNG,
     impurity_importance:: Bool = true) where {S, T <: Float64}

@@ -59,11 +60,11 @@ function build_tree(
     node = _convert(t.root, labels[t.labels])
     n_features = size(features, 2)
     if !impurity_importance
-        return Root{S, T}(node, n_features, Float64[])
+        return Root{S, T, n_classes}(node, n_features, Float64[])
     else
         fi = zeros(Float64, n_features)
         update_using_impurity!(fi, t.root)
-        return Root{S, T}(node, n_features, fi ./ size(features, 1))
+        return Root{S, T, n_classes}(node, n_features, fi ./ size(features, 1))
     end
 end

@@ -77,6 +78,7 @@ function build_forest(
     min_samples_leaf    = 5,
     min_samples_split   = 2,
     min_purity_increase = 0.0;
+    n_classes          :: Int = length(unique(labels)),
     rng::Union{Integer,AbstractRNG} = Random.GLOBAL_RNG,
     impurity_importance :: Bool = true) where {S, T <: Float64}

@@ -110,7 +112,8 @@ function build_forest(
                 max_depth,
                 min_samples_leaf,
                 min_samples_split,
-                min_purity_increase,
+                min_purity_increase;
+                n_classes,
                 rng = _rng,
                 impurity_importance = impurity_importance)
         end
@@ -125,7 +128,8 @@ function build_forest(
                 max_depth,
                 min_samples_leaf,
                 min_samples_split,
-                min_purity_increase,
+                min_purity_increase;
+                n_classes,
                 impurity_importance = impurity_importance)
         end
     end
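
The regression build_tree and build_forest gain the same keyword; a minimal sketch under the same assumptions (toy data invented for illustration), where n_classes counts the distinct target values stored per Leaf:

    using DecisionTree

    features = reshape(collect(1.0:6.0), :, 1)   # toy 6×1 feature matrix
    labels   = [0.5, 0.5, 1.5, 1.5, 2.5, 2.5]    # three distinct target values

    model = build_tree(labels, features)          # n_classes defaults to 3 here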

test/miscellaneous/abstract_trees_test.jl

Lines changed: 4 additions & 4 deletions
@@ -17,9 +17,9 @@ clabel_pattern(clabel) = "─ " * clabel * " (" # class labels are embedde
 check_occurence(str_tree, pool, pattern) = count(map(elem -> occursin(pattern(elem), str_tree), pool)) == length(pool)

 @info("Test base functionality")
-l1 = Leaf(1, [1,1,2])
-l2 = Leaf(2, [1,2,2])
-l3 = Leaf(3, [3,3,1])
+l1 = Leaf((1,2,3), 1, (2, 1, 0), 3)
+l2 = Leaf((1,2,3), 2, (1, 2, 0), 3)
+l3 = Leaf((1,2,3), 3, (1, 0, 2), 3)
 n2 = Node(2, 0.5, l2, l3)
 n1 = Node(1, 0.7, l1, n2)
 feature_names = ["firstFt", "secondFt"]
@@ -81,4 +81,4 @@ end
 traverse_tree(leaf::InfoLeaf) = nothing

 traverse_tree(wrapped_tree)
-end
+end
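
For reference, the positional arguments in the updated test leaves map onto the Leaf fields defined in src/DecisionTree.jl above (a sketch, with the meaning of each argument spelled out in comments):

    # Leaf(classes, majority, values, total)
    l1 = Leaf((1, 2, 3),  # all class labels known to the tree
              1,          # index into classes of the predicted (majority) class
              (2, 1, 0),  # per-class sample counts at this leaf
              3)          # total samples, i.e. the sum of the counts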

test/miscellaneous/convert.jl

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,7 @@

 @testset "convert.jl" begin

-lf = Leaf(1, [1])
+lf = Leaf([1], [1])
 nv = Node{Int, Int}[]
 rv = Root{Int, Int}[]
 push!(nv, lf)
@@ -22,7 +22,7 @@ push!(rv, nv[1])
 @test apply_tree(rv[1], [0]) == 1.0
 @test apply_tree(rv[2], [0]) == 1.0

-lf = Leaf("A", ["B", "A"])
+lf = Leaf(["A", "B"], [2, 1])
 nv = Node{Int, String}[]
 rv = Root{Int, String}[]
 push!(nv, lf)
