@@ -594,9 +594,8 @@ def __init__(
594
594
vit_window_size = 8 ,
595
595
vit_mbconv_expansion_rate = 4 ,
596
596
vit_mbconv_shrinkage_rate = 0.25 ,
597
- sparse_input_2496_channels = 8 ,
598
- dense_input_2496_channels = 8 ,
599
- dense_input_4996_channels = 8 ,
597
+ input_2496_channels = 2 + 14 + 1 + 2 + 20 ,
598
+ input_4996_channels = 16 + 1 ,
600
599
surface_and_hrrr_target_spatial_size = 128 ,
601
600
precipitation_target_bins : Dict [str , int ] = dict (
602
601
mrms_rate = 512 ,
@@ -610,26 +609,26 @@ def __init__(
610
609
omo_wind_component_y = 256 ,
611
610
omo_wind_direction = 180
612
611
),
613
- hrrr_target_channels = 617 ,
612
+ hrrr_channels = 617 ,
614
613
hrrr_norm_statistics : Optional [Tensor ] = None ,
615
614
hrrr_loss_weight = 10 ,
616
615
crop_size_post_16km = 48 ,
617
616
resnet_block_depth = 2 ,
618
617
):
619
618
super ().__init__ ()
620
- self .sparse_input_2496_shape = (sparse_input_2496_channels , input_spatial_size , input_spatial_size )
621
- self .dense_input_2496_shape = (dense_input_2496_channels , input_spatial_size , input_spatial_size )
622
- self .dense_input_4996_shape = (dense_input_4996_channels , input_spatial_size , input_spatial_size )
619
+ self .hrrr_input_2496_shape = (hrrr_channels , input_spatial_size , input_spatial_size )
620
+ self .input_2496_shape = (input_2496_channels , input_spatial_size , input_spatial_size )
621
+ self .input_4996_shape = (input_4996_channels , input_spatial_size , input_spatial_size )
623
622
624
623
self .surface_and_hrrr_target_spatial_size = surface_and_hrrr_target_spatial_size
625
624
626
625
self .surface_target_shape = ((self .surface_and_hrrr_target_spatial_size ,) * 2 )
627
- self .hrrr_target_shape = (hrrr_target_channels , * self .surface_target_shape )
626
+ self .hrrr_target_shape = (hrrr_channels , * self .surface_target_shape )
628
627
self .precipitation_target_shape = (surface_and_hrrr_target_spatial_size * 4 ,) * 2
629
628
630
629
self .lead_time_embedding = nn .Embedding (num_lead_times , lead_time_embed_dim )
631
630
632
- dim_in_4km = sparse_input_2496_channels + dense_input_2496_channels
631
+ dim_in_4km = hrrr_channels + input_2496_channels + 1
633
632
634
633
self .to_skip_connect_4km = CenterCrop (crop_size_post_16km * 4 )
635
634
@@ -645,7 +644,7 @@ def __init__(
645
644
CenterPad (input_spatial_size )
646
645
)
647
646
648
- dim_in_8km = dense_input_4996_channels + dim
647
+ dim_in_8km = input_4996_channels + dim
649
648
650
649
self .resnet_blocks_down_8km = ResnetBlocks (
651
650
dim = dim ,
@@ -726,53 +725,79 @@ def __init__(
726
725
727
726
self .to_hrrr_pred = Sequential (
728
727
ChanLayerNorm (dim ),
729
- nn .Conv2d (dim , hrrr_target_channels , 1 )
728
+ nn .Conv2d (dim , hrrr_channels , 1 )
730
729
)
731
730
732
731
# they scale the hrrr loss by 10, but also divide it by the number of channels
733
732
734
- self .hrrr_loss_weight = hrrr_loss_weight / hrrr_target_channels
733
+ self .hrrr_loss_weight = hrrr_loss_weight / hrrr_channels
735
734
736
735
self .has_hrrr_norm_statistics = exists (hrrr_norm_statistics )
737
736
738
737
if self .has_hrrr_norm_statistics :
739
- assert hrrr_norm_statistics .shape == (2 , hrrr_target_channels ), f'normalization statistics must be of shape (2, { normed_hrrr_target } ), containing mean and variance of each target calculated from the dataset'
738
+ assert hrrr_norm_statistics .shape == (2 , hrrr_channels ), f'normalization statistics must be of shape (2, { normed_hrrr_target } ), containing mean and variance of each target calculated from the dataset'
740
739
self .register_buffer ('hrrr_norm_statistics' , hrrr_norm_statistics )
741
740
else :
742
- self .batchnorm_hrrr = MaybeSyncBatchnorm2d ()(hrrr_target_channels , affine = False )
741
+ self .batchnorm_hrrr = MaybeSyncBatchnorm2d ()(hrrr_channels , affine = False )
743
742
744
743
self .mse_loss_scaler = LossScaler ()
745
744
746
745
@beartype
747
746
def forward (
748
747
self ,
748
+ * ,
749
749
lead_times ,
750
- sparse_input_2496 ,
751
- dense_input_2496 ,
752
- dense_input_4996 ,
750
+ hrrr_input_2496 ,
751
+ hrrr_stale_state ,
752
+ input_2496 ,
753
+ input_4996 ,
753
754
surface_targets : Optional [Dict [str , Tensor ]] = None ,
754
755
precipitation_targets : Optional [Dict [str , Tensor ]] = None ,
755
756
hrrr_target : Optional [Tensor ] = None ,
756
757
):
757
758
batch = lead_times .shape [0 ]
758
759
759
- assert batch == sparse_input_2496 .shape [0 ] == dense_input_2496 .shape [0 ] == dense_input_4996 .shape [0 ], 'batch size across all inputs must be the same'
760
+ assert batch == hrrr_input_2496 .shape [0 ] == input_2496 .shape [0 ] == input_4996 .shape [0 ], 'batch size across all inputs must be the same'
760
761
761
- assert sparse_input_2496 .shape [1 :] == self .sparse_input_2496_shape
762
- assert dense_input_2496 .shape [1 :] == self .dense_input_2496_shape
763
- assert dense_input_4996 .shape [1 :] == self .dense_input_4996_shape
762
+ assert hrrr_input_2496 .shape [1 :] == self .hrrr_input_2496_shape
763
+ assert input_2496 .shape [1 :] == self .input_2496_shape
764
+ assert input_4996 .shape [1 :] == self .input_4996_shape
765
+
766
+ # normalize HRRR input and target as needed
767
+
768
+ if self .has_hrrr_norm_statistics :
769
+ mean , variance = self .hrrr_norm_statistics
770
+ mean = rearrange (mean , 'c -> c 1 1' )
771
+ variance = rearrange (variance , 'c -> c 1 1' )
772
+ inv_std = variance .clamp (min = 1e-5 ).rsqrt ()
773
+
774
+ normed_hrrr_input = (hrrr_input_2496 - mean ) * inv_std
775
+
776
+ if exists (hrrr_target ):
777
+ normed_hrrr_target = (hrrr_target - mean ) * inv_std
778
+
779
+ else :
780
+ # use a batchnorm to normalize each channel to mean zero and unit variance
781
+
782
+ with freeze_batchnorm (self .batchnorm_hrrr ) as frozen_batchnorm :
783
+ normed_hrrr_input = frozen_batchnorm (hrrr_input_2496 )
784
+
785
+ if exists (hrrr_target ):
786
+ normed_hrrr_target = frozen_batchnorm (hrrr_target )
787
+
788
+ # main network
764
789
765
790
cond = self .lead_time_embedding (lead_times )
766
791
767
- x = torch .cat ((sparse_input_2496 , dense_input_2496 ), dim = 1 )
792
+ x = torch .cat ((normed_hrrr_input , hrrr_stale_state , input_2496 ), dim = 1 )
768
793
769
794
skip_connect_4km = self .to_skip_connect_4km (x )
770
795
771
796
x = self .resnet_blocks_down_4km (x , cond = cond )
772
797
773
798
x = self .downsample_and_pad_to_8km (x )
774
799
775
- x = torch .cat ((dense_input_4996 , x ), dim = 1 )
800
+ x = torch .cat ((input_4996 , x ), dim = 1 )
776
801
777
802
skip_connect_8km = self .to_skip_connect_8km (x )
778
803
@@ -866,30 +891,16 @@ def forward(
866
891
ce_losses = ce_losses + precipition_loss
867
892
868
893
# calculate HRRR mse loss
894
+ # proposed loss gradient rescaler from section 4.3.2
869
895
870
- if self .has_hrrr_norm_statistics :
871
- mean , variance = self .hrrr_norm_statistics
872
- mean = rearrange (mean , 'c -> c 1 1' )
873
- variance = rearrange (variance , 'c -> c 1 1' )
874
- inv_std = variance .clamp (min = 1e-5 ).rsqrt ()
875
-
876
- normed_hrrr_target = (hrrr_target - mean ) * inv_std
877
- normed_hrrr_pred = (hrrr_pred - mean ) * inv_std
878
- else :
879
- # use a batchnorm to normalize each channel to mean zero and unit variance
880
-
881
- if self .training :
882
- _ = self .batchnorm_hrrr (hrrr_target )
883
-
884
- with freeze_batchnorm (self .batchnorm_hrrr ) as frozen_batchnorm :
885
- normed_hrrr_pred = frozen_batchnorm (hrrr_pred )
886
- normed_hrrr_target = frozen_batchnorm (hrrr_target )
896
+ hrrr_pred = self .mse_loss_scaler (hrrr_pred )
887
897
888
- # proposed loss gradient rescaler from section 4.3.2
898
+ hrrr_loss = F . mse_loss ( hrrr_pred , normed_hrrr_target )
889
899
890
- normed_hrrr_pred = self . mse_loss_scaler ( normed_hrrr_pred )
900
+ # update the hrrr normalization statistics when normalizing with the batchnorm
891
901
892
- hrrr_loss = F .mse_loss (normed_hrrr_pred , normed_hrrr_target )
902
+ if not self .has_hrrr_norm_statistics and self .training :
903
+ _ = self .batchnorm_hrrr (hrrr_target )
893
904
894
905
# total loss
895
906
0 commit comments