@@ -51,8 +51,15 @@ defmodule Axon.Compiler do
51
51
raise_on_none? = Keyword . get ( opts , :raise_on_none , true )
52
52
mode = Keyword . get ( opts , :mode , :inference )
53
53
seed = Keyword . get_lazy ( opts , :seed , fn -> :erlang . system_time ( ) end )
54
+ print_values = Keyword . get ( opts , :print_values , false )
54
55
global_layer_options = Keyword . get ( opts , :global_layer_options , [ ] )
55
- config = % { mode: mode , debug?: debug? , global_layer_options: global_layer_options }
56
+
57
+ config = % {
58
+ mode: mode ,
59
+ debug?: debug? ,
60
+ global_layer_options: global_layer_options ,
61
+ print_values: print_values
62
+ }
56
63
57
64
{ time , { root_id , { cache , _op_counts , _block_cache , model_state_meta } } } =
58
65
:timer . tc ( fn ->
@@ -446,16 +453,21 @@ defmodule Axon.Compiler do
446
453
end
447
454
448
455
defp recur_model_funs (
449
- % Axon.Node { id: id , op: :constant , opts: [ value: tensor ] , policy: policy } ,
456
+ % Axon.Node { id: id , name: name_fn , op: :constant , opts: [ value: tensor ] , policy: policy } ,
450
457
_nodes ,
451
458
{ cache , op_counts , block_cache , model_state_meta } ,
452
- _
459
+ % { print_values: print_values }
453
460
) do
461
+ name = name_fn . ( :constant , op_counts )
454
462
op_counts = Map . update ( op_counts , :constant , 1 , fn x -> x + 1 end )
455
463
tensor = Nx . backend_copy ( tensor , Nx.BinaryBackend )
456
464
457
465
predict_fun = fn _params , _inputs , state , _cache , result_cache , _fn_stacktrace ->
458
- out = safe_policy_cast ( tensor , policy , :output )
466
+ out =
467
+ tensor
468
+ |> safe_policy_cast ( policy , :output )
469
+ |> maybe_print_values ( name , print_values )
470
+
459
471
{ out , { state , result_cache } }
460
472
end
461
473
@@ -477,7 +489,7 @@ defmodule Axon.Compiler do
477
489
} ,
478
490
_nodes ,
479
491
{ cache , op_counts , block_cache , model_state_meta } ,
480
- % { mode: mode }
492
+ % { mode: mode , print_values: print_values }
481
493
) do
482
494
name = name_fn . ( :input , op_counts )
483
495
op_counts = Map . update ( op_counts , :input , 1 , fn x -> x + 1 end )
@@ -492,6 +504,7 @@ defmodule Axon.Compiler do
492
504
value
493
505
|> apply_hooks ( :forward , mode , hooks )
494
506
|> apply_hooks ( :backward , mode , hooks )
507
+ |> maybe_print_values ( name , print_values )
495
508
496
509
{ res , { state , result_cache } }
497
510
end
@@ -687,6 +700,8 @@ defmodule Axon.Compiler do
687
700
Map . put ( state , block_name , out_state )
688
701
end
689
702
703
+ out_result = maybe_print_values ( out_result , block_name , config . print_values )
704
+
690
705
{ out_result , { state , result_cache } }
691
706
end
692
707
end
@@ -847,7 +862,12 @@ defmodule Axon.Compiler do
847
862
} ,
848
863
nodes ,
849
864
cache_and_counts ,
850
- % { mode: mode , debug?: debug? , global_layer_options: global_layer_options } = config
865
+ % {
866
+ mode: mode ,
867
+ debug?: debug? ,
868
+ global_layer_options: global_layer_options ,
869
+ print_values: print_values
870
+ } = config
851
871
)
852
872
when ( is_function ( op ) or is_atom ( op ) ) and is_list ( inputs ) do
853
873
# Traverse to accumulate cache and get parent_ids for
@@ -912,6 +932,7 @@ defmodule Axon.Compiler do
912
932
hooks ,
913
933
mode ,
914
934
global_layer_options ,
935
+ print_values ,
915
936
stacktrace
916
937
)
917
938
@@ -994,6 +1015,7 @@ defmodule Axon.Compiler do
994
1015
hooks ,
995
1016
mode ,
996
1017
global_layer_options ,
1018
+ print_values ,
997
1019
layer_stacktrace
998
1020
) do
999
1021
# Recurse graph inputs and invoke cache to get parent results,
@@ -1113,6 +1135,8 @@ defmodule Axon.Compiler do
1113
1135
{ new_out , state }
1114
1136
end
1115
1137
1138
+ out = maybe_print_values ( out , name , print_values )
1139
+
1116
1140
{ out , { state , result_cache } }
1117
1141
end
1118
1142
end
@@ -1270,6 +1294,12 @@ defmodule Axon.Compiler do
1270
1294
defp maybe_freeze ( param , true ) , do: Nx.Defn.Kernel . stop_grad ( param )
1271
1295
defp maybe_freeze ( param , false ) , do: param
1272
1296
1297
+ defp maybe_print_values ( value , layer , true ) do
1298
+ Nx.Defn.Kernel . print_value ( value , label: layer )
1299
+ end
1300
+
1301
+ defp maybe_print_values ( value , _ , _ ) , do: value
1302
+
1273
1303
defp apply_hooks ( res , event , mode , hooks ) do
1274
1304
hooks
1275
1305
|> Enum . reverse ( )
0 commit comments