|
| 1 | +using Pkg |
| 2 | +# To ensure we benchmark the local version of DynamicPPL, dev the folder above. |
| 3 | +Pkg.develop(; path=joinpath(@__DIR__, "..")) |
| 4 | + |
| 5 | +using DynamicPPLBenchmarks: Models, make_suite, model_dimension |
| 6 | +using BenchmarkTools: @benchmark, median, run |
| 7 | +using PrettyTables: PrettyTables, ft_printf |
| 8 | +using StableRNGs: StableRNG |
| 9 | + |
| 10 | +rng = StableRNG(23) |
| 11 | + |
| 12 | +# Create DynamicPPL.Model instances to run benchmarks on. |
| 13 | +smorgasbord_instance = Models.smorgasbord(randn(rng, 100), randn(rng, 100)) |
| 14 | +loop_univariate1k, multivariate1k = begin |
| 15 | + data_1k = randn(rng, 1_000) |
| 16 | + loop = Models.loop_univariate(length(data_1k)) | (; o=data_1k) |
| 17 | + multi = Models.multivariate(length(data_1k)) | (; o=data_1k) |
| 18 | + loop, multi |
| 19 | +end |
| 20 | +loop_univariate10k, multivariate10k = begin |
| 21 | + data_10k = randn(rng, 10_000) |
| 22 | + loop = Models.loop_univariate(length(data_10k)) | (; o=data_10k) |
| 23 | + multi = Models.multivariate(length(data_10k)) | (; o=data_10k) |
| 24 | + loop, multi |
| 25 | +end |
| 26 | +lda_instance = begin |
| 27 | + w = [1, 2, 3, 2, 1, 1] |
| 28 | + d = [1, 1, 1, 2, 2, 2] |
| 29 | + Models.lda(2, d, w) |
| 30 | +end |
| 31 | + |
| 32 | +# Specify the combinations to test: |
| 33 | +# (Model Name, model instance, VarInfo choice, AD backend, linked) |
| 34 | +chosen_combinations = [ |
| 35 | + ( |
| 36 | + "Simple assume observe", |
| 37 | + Models.simple_assume_observe(randn(rng)), |
| 38 | + :typed, |
| 39 | + :forwarddiff, |
| 40 | + false, |
| 41 | + ), |
| 42 | + ("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false), |
| 43 | + ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true), |
| 44 | + ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true), |
| 45 | + ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true), |
| 46 | + ("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true), |
| 47 | + ("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true), |
| 48 | + ("Loop univariate 1k", loop_univariate1k, :typed, :mooncake, true), |
| 49 | + ("Multivariate 1k", multivariate1k, :typed, :mooncake, true), |
| 50 | + ("Loop univariate 10k", loop_univariate10k, :typed, :mooncake, true), |
| 51 | + ("Multivariate 10k", multivariate10k, :typed, :mooncake, true), |
| 52 | + ("Dynamic", Models.dynamic(), :typed, :mooncake, true), |
| 53 | + ("Submodel", Models.parent(randn(rng)), :typed, :mooncake, true), |
| 54 | + ("LDA", lda_instance, :typed, :reversediff, true), |
| 55 | +] |
| 56 | + |
| 57 | +# Time running a model-like function that does not use DynamicPPL, as a reference point. |
| 58 | +# Eval timings will be relative to this. |
| 59 | +reference_time = begin |
| 60 | + obs = randn(rng) |
| 61 | + median(@benchmark Models.simple_assume_observe_non_model(obs)).time |
| 62 | +end |
| 63 | + |
| 64 | +results_table = Tuple{String,Int,String,String,Bool,Float64,Float64}[] |
| 65 | + |
| 66 | +for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations |
| 67 | + @info "Running benchmark for $model_name" |
| 68 | + suite = make_suite(model, varinfo_choice, adbackend, islinked) |
| 69 | + results = run(suite) |
| 70 | + eval_time = median(results["evaluation"]).time |
| 71 | + relative_eval_time = eval_time / reference_time |
| 72 | + ad_eval_time = median(results["gradient"]).time |
| 73 | + relative_ad_eval_time = ad_eval_time / eval_time |
| 74 | + push!( |
| 75 | + results_table, |
| 76 | + ( |
| 77 | + model_name, |
| 78 | + model_dimension(model, islinked), |
| 79 | + string(adbackend), |
| 80 | + string(varinfo_choice), |
| 81 | + islinked, |
| 82 | + relative_eval_time, |
| 83 | + relative_ad_eval_time, |
| 84 | + ), |
| 85 | + ) |
| 86 | +end |
| 87 | + |
| 88 | +table_matrix = hcat(Iterators.map(collect, zip(results_table...))...) |
| 89 | +header = [ |
| 90 | + "Model", |
| 91 | + "Dimension", |
| 92 | + "AD Backend", |
| 93 | + "VarInfo Type", |
| 94 | + "Linked", |
| 95 | + "Eval Time / Ref Time", |
| 96 | + "AD Time / Eval Time", |
| 97 | +] |
| 98 | +PrettyTables.pretty_table( |
| 99 | + table_matrix; |
| 100 | + header=header, |
| 101 | + tf=PrettyTables.tf_markdown, |
| 102 | + formatters=ft_printf("%.1f", [6, 7]), |
| 103 | +) |
0 commit comments