diff --git a/Project.toml b/Project.toml index 20248eb..de99baa 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" MultiData = "8cc5100c-b3d1-4f82-90cb-0ea93d317aba" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Query = "1a8c2f83-1ff3-5112-b086-8aa67b057ba1" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/src/scalar/main.jl b/src/scalar/main.jl index bb14308..3bec88a 100644 --- a/src/scalar/main.jl +++ b/src/scalar/main.jl @@ -17,6 +17,8 @@ include("random.jl") include("representatives.jl") +include("visualizations.jl") + # # Types for representing common associations between features and operators # include("canonical-conditions.jl") # TODO remove diff --git a/src/scalar/var-features.jl b/src/scalar/var-features.jl index 3a65e57..6ee15ca 100644 --- a/src/scalar/var-features.jl +++ b/src/scalar/var-features.jl @@ -2,7 +2,7 @@ import SoleData: AbstractFeature using MultiData: instance_channel -import Base: show +import Base: show, isless import SoleLogics: syntaxstring # Feature parentheses (e.g., for parsing/showing "main[V2]") @@ -301,6 +301,8 @@ struct VariableValue{I<:VariableId, N<:Union{VariableName, Nothing}} <: Abstract end featurename(f::VariableValue) = !isnothing(f.i_name) ? f.i_name : "" +Base.isless(a::VariableValue, b::VariableValue) = isless(a.i_variable, b.i_variable) + function syntaxstring(f::VariableValue; variable_names_map = nothing, show_colon = false, kwargs...) if !isnothing(f.i_name) opening_parenthesis = UVF_OPENING_PARENTHESIS diff --git a/src/scalar/visualizations.jl b/src/scalar/visualizations.jl new file mode 100644 index 0000000..5d5bdf4 --- /dev/null +++ b/src/scalar/visualizations.jl @@ -0,0 +1,150 @@ +export show_scalardnf + +using Printf: @sprintf +using SoleData: AbstractScalarCondition + +function extract_ranges(conjunction::LeftmostConjunctiveForm) + by_var = Dict{Any, Tuple{Float64,Bool,Float64,Bool}}() + for atom in SoleLogics.conjuncts(conjunction) + @assert atom isa Atom + cond = SoleLogics.value(atom) + @assert cond isa AbstractScalarCondition typeof(cond) + feat = SoleData.feature(cond) + minv = isnothing(SoleData.minval(cond)) ? -Inf : SoleData.minval(cond) + maxv = isnothing(SoleData.maxval(cond)) ? Inf : SoleData.maxval(cond) + mini = SoleData.minincluded(cond) + maxi = SoleData.maxincluded(cond) + if haskey(by_var, feat) + # Intersezione: aggiorna i bordi + (curmin, curmini, curmax, curmaxi) = by_var[feat] + # aggiorna min + if minv > curmin + curmin = minv + curmini = mini + elseif minv == curmin + curmini = curmini && mini + end + # aggiorna max + if maxv < curmax + curmax = maxv + curmaxi = maxi + elseif maxv == curmax + curmaxi = curmaxi && maxi + end + by_var[feat] = (curmin, curmini, curmax, curmaxi) + else + by_var[feat] = (minv, mini, maxv, maxi) + end + end + return by_var +end + + +function collect_thresholds(all_ranges) + thresholds = Set{Float64}() + for ranges in all_ranges + for (_, (minv, _mini, maxv, _maxi)) in pairs(ranges) + push!(thresholds, minv) + push!(thresholds, maxv) + end + end + return sort(collect(thresholds)) +end + +function draw_bar(minv, mini, maxv, maxi, thresholds; colwidth=5, body_char = "-") + nseg = length(thresholds) - 1 + segments = fill(" " ^ colwidth, nseg) + + for i = 1:nseg + t0 = thresholds[i] + t1 = thresholds[i+1] + if maxv <= t0 || minv >= t1 + continue + else + segments[i] = body_char ^ colwidth + end + end + + first_idx = findfirst(s -> occursin(body_char, s), segments) + last_idx = findlast(s -> occursin(body_char, s), segments) + + if first_idx !== nothing + segments[first_idx] = let x = collect(segments[first_idx]) + x[1] = (mini ? '[' : '(') + String(x) + end + end + if last_idx !== nothing + segments[last_idx] = let x = collect(segments[last_idx]) + x[colwidth] = (maxi ? ']' : ')') + String(x) + end + end + + return " " ^ colwidth * join(segments) +end + + +show_scalardnf(f::DNF; kwargs...) = show_scalardnf(stdout, f; kwargs...) + +function show_scalardnf( + io::IO, + formula::DNF; + show_unbounded=true, + colwidth=5, + body_char='=', # alternatives: ■, ━ + print_disjuncts=false, + palette=[:cyan, :green, :yellow, :magenta, :blue] +) + @assert colwidth >= 5 + formula = normalize(formula) + disjs = SoleLogics.disjuncts(formula) + all_ranges = [extract_ranges(d) for d in disjs] + + # raccogli tutte le variabili + all_vars = Set{Any}() + for ranges in all_ranges, v in keys(ranges) + push!(all_vars, v) + end + all_vars = sort(collect(all_vars)) + thresholds = collect_thresholds(all_ranges) + + # calcola larghezza massima nome variabile + namewidth = maximum(length(syntaxstring(v)) for v in all_vars) + + # header + header = " " ^ (3+colwidth+namewidth) + for t in thresholds + header *= @sprintf("%-*.*f", colwidth, 2, t) + end + println(io, header) + println(io) + + # mappa variabili -> colori + colors=Dict(), + var_colors = Dict{Any,Symbol}() + for (i, v) in enumerate(all_vars) + var_colors[v] = get(colors, v, palette[(i-1) % length(palette) + 1]) + end + + for (i, (d, ranges)) in enumerate(zip(disjs, all_ranges)) + print_disjuncts && println(io, "Disjunct $i: ", syntaxstring(normalize(d))) + for v in all_vars + if haskey(ranges, v) + (minv, mini, maxv, maxi) = ranges[v] + elseif show_unbounded + minv, mini, maxv, maxi = (-Inf, true, +Inf, true) + else + continue + end + bar = draw_bar(minv, mini, maxv, maxi, thresholds; colwidth, body_char) + # colore + color = var_colors[v] + # stampo nome e barre + print(io, " ") + printstyled(io, rpad(syntaxstring(v), namewidth), " : ", bar, color=color) + println(io) + end + println(io) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 5bc00aa..69aa412 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -39,6 +39,7 @@ test_suites = [ ("Conditions", [ "range-scalar-condition.jl", ]), ("Alphabets", [ "scalar-alphabet.jl", "discretization.jl"]), ("Features", [ "patchedfeatures.jl"]), + ("Visualizations", [ "visualizations.jl", ]), # ("MLJ", [ "MLJ.jl", ]), ("PLA", [ "pla.jl", ]), diff --git a/test/visualizations.jl b/test/visualizations.jl new file mode 100644 index 0000000..7e32dcc --- /dev/null +++ b/test/visualizations.jl @@ -0,0 +1,41 @@ +using SoleData + +@testset "Visualizations" begin + f = @scalarformula( + ((V1 < 5.85) ∧ (V1 ≥ 5.65) ∧ (V2 < 2.85) ∧ (V3 < 4.55) ∧ (V3 ≥ 4.45) ∧ (V4 < 0.35)) ∨ + ((V1 < 5.3) ∧ (V2 ≥ 2.85) ∧ (V3 < 5.05) ∧ (V3 ≥ 4.85) ∧ (V4 < 0.35)) + ) |> dnf + + io = IOBuffer() + show_scalardnf( + io, + f; + show_unbounded=true, + colwidth=6, + ) + + out = String(take!(io)) + + check_same(a, b) = myclean(a) == myclean(b) + function myclean(a) + lines = collect(filter(x->x != "", split(out, '\n'))) + join(map(x -> rstrip(x), lines), "\n") + end + + @test check_same(out, """ + -Inf 0.35 2.85 4.45 4.55 4.85 5.05 5.30 5.65 5.85 Inf + + V1 : [====) + V2 : [==========) + V3 : [====) + V4 : [====) + + V1 : [========================================) + V2 : [==============================================] + V3 : [====) + V4 : [====) + + + """) + +end \ No newline at end of file