Skip to content

Commit 8cee5a9

Browse files
authored
Add inspect_values option (#581)
1 parent 57cd12f commit 8cee5a9

File tree

3 files changed

+63
-6
lines changed

3 files changed

+63
-6
lines changed

lib/axon.ex

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3761,6 +3761,10 @@ defmodule Axon do
37613761
metrics. Also forwarded to JIT if debug mode is available
37623762
for your chosen compiler or backend. Defaults to `false`
37633763
3764+
* `:print_values` - if `true`, will print intermediate layer
3765+
values to the screen for inspection. This is useful if you need
3766+
to debug intermediate values of a model
3767+
37643768
* `:mode` - one of `:inference` or `:train`. Forwarded to layers
37653769
to control differences in compilation at training or inference time.
37663770
Defaults to `:inference`

lib/axon/compiler.ex

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,15 @@ defmodule Axon.Compiler do
5151
raise_on_none? = Keyword.get(opts, :raise_on_none, true)
5252
mode = Keyword.get(opts, :mode, :inference)
5353
seed = Keyword.get_lazy(opts, :seed, fn -> :erlang.system_time() end)
54+
print_values = Keyword.get(opts, :print_values, false)
5455
global_layer_options = Keyword.get(opts, :global_layer_options, [])
55-
config = %{mode: mode, debug?: debug?, global_layer_options: global_layer_options}
56+
57+
config = %{
58+
mode: mode,
59+
debug?: debug?,
60+
global_layer_options: global_layer_options,
61+
print_values: print_values
62+
}
5663

5764
{time, {root_id, {cache, _op_counts, _block_cache, model_state_meta}}} =
5865
:timer.tc(fn ->
@@ -446,16 +453,21 @@ defmodule Axon.Compiler do
446453
end
447454

448455
defp recur_model_funs(
449-
%Axon.Node{id: id, op: :constant, opts: [value: tensor], policy: policy},
456+
%Axon.Node{id: id, name: name_fn, op: :constant, opts: [value: tensor], policy: policy},
450457
_nodes,
451458
{cache, op_counts, block_cache, model_state_meta},
452-
_
459+
%{print_values: print_values}
453460
) do
461+
name = name_fn.(:constant, op_counts)
454462
op_counts = Map.update(op_counts, :constant, 1, fn x -> x + 1 end)
455463
tensor = Nx.backend_copy(tensor, Nx.BinaryBackend)
456464

457465
predict_fun = fn _params, _inputs, state, _cache, result_cache, _fn_stacktrace ->
458-
out = safe_policy_cast(tensor, policy, :output)
466+
out =
467+
tensor
468+
|> safe_policy_cast(policy, :output)
469+
|> maybe_print_values(name, print_values)
470+
459471
{out, {state, result_cache}}
460472
end
461473

@@ -477,7 +489,7 @@ defmodule Axon.Compiler do
477489
},
478490
_nodes,
479491
{cache, op_counts, block_cache, model_state_meta},
480-
%{mode: mode}
492+
%{mode: mode, print_values: print_values}
481493
) do
482494
name = name_fn.(:input, op_counts)
483495
op_counts = Map.update(op_counts, :input, 1, fn x -> x + 1 end)
@@ -492,6 +504,7 @@ defmodule Axon.Compiler do
492504
value
493505
|> apply_hooks(:forward, mode, hooks)
494506
|> apply_hooks(:backward, mode, hooks)
507+
|> maybe_print_values(name, print_values)
495508

496509
{res, {state, result_cache}}
497510
end
@@ -687,6 +700,8 @@ defmodule Axon.Compiler do
687700
Map.put(state, block_name, out_state)
688701
end
689702

703+
out_result = maybe_print_values(out_result, block_name, config.print_values)
704+
690705
{out_result, {state, result_cache}}
691706
end
692707
end
@@ -847,7 +862,12 @@ defmodule Axon.Compiler do
847862
},
848863
nodes,
849864
cache_and_counts,
850-
%{mode: mode, debug?: debug?, global_layer_options: global_layer_options} = config
865+
%{
866+
mode: mode,
867+
debug?: debug?,
868+
global_layer_options: global_layer_options,
869+
print_values: print_values
870+
} = config
851871
)
852872
when (is_function(op) or is_atom(op)) and is_list(inputs) do
853873
# Traverse to accumulate cache and get parent_ids for
@@ -912,6 +932,7 @@ defmodule Axon.Compiler do
912932
hooks,
913933
mode,
914934
global_layer_options,
935+
print_values,
915936
stacktrace
916937
)
917938

@@ -994,6 +1015,7 @@ defmodule Axon.Compiler do
9941015
hooks,
9951016
mode,
9961017
global_layer_options,
1018+
print_values,
9971019
layer_stacktrace
9981020
) do
9991021
# Recurse graph inputs and invoke cache to get parent results,
@@ -1113,6 +1135,8 @@ defmodule Axon.Compiler do
11131135
{new_out, state}
11141136
end
11151137

1138+
out = maybe_print_values(out, name, print_values)
1139+
11161140
{out, {state, result_cache}}
11171141
end
11181142
end
@@ -1270,6 +1294,12 @@ defmodule Axon.Compiler do
12701294
defp maybe_freeze(param, true), do: Nx.Defn.Kernel.stop_grad(param)
12711295
defp maybe_freeze(param, false), do: param
12721296

1297+
defp maybe_print_values(value, layer, true) do
1298+
Nx.Defn.Kernel.print_value(value, label: layer)
1299+
end
1300+
1301+
defp maybe_print_values(value, _, _), do: value
1302+
12731303
defp apply_hooks(res, event, mode, hooks) do
12741304
hooks
12751305
|> Enum.reverse()

test/axon/compiler_test.exs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5716,4 +5716,27 @@ defmodule CompilerTest do
57165716
assert Nx.shape(out) == {1, 20, 32}
57175717
end
57185718
end
5719+
5720+
describe "inspect values" do
5721+
test "prints intermediate layer values to the screen" do
5722+
model =
5723+
Axon.input("x")
5724+
|> Axon.dense(10, name: "foo")
5725+
|> Axon.dense(4, name: "bar")
5726+
5727+
{init_fn, predict_fn} = Axon.build(model, print_values: true)
5728+
input = Nx.broadcast(1, {1, 10})
5729+
5730+
model_state = init_fn.(input, ModelState.empty())
5731+
5732+
out =
5733+
ExUnit.CaptureIO.capture_io(fn ->
5734+
predict_fn.(model_state, input)
5735+
end)
5736+
5737+
assert out =~ "x:"
5738+
assert out =~ "foo:"
5739+
assert out =~ "bar:"
5740+
end
5741+
end
57195742
end

0 commit comments

Comments
 (0)