Skip to content

Commit 98f99ba

Browse files
committed
first pass at inputs, redo hrrr
1 parent 2e86e5b commit 98f99ba

File tree

3 files changed

+74
-60
lines changed

3 files changed

+74
-60
lines changed

README.md

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ metnet3 = MetNet3(
2828
lead_time_embed_dim = 32,
2929
input_spatial_size = 624,
3030
attn_dim_head = 8,
31-
sparse_input_2496_channels = 8,
32-
dense_input_2496_channels = 8,
33-
dense_input_4996_channels = 8,
31+
hrrr_channels = 617,
32+
input_2496_channels = 2 + 14 + 1 + 2 + 20,
33+
input_4996_channels = 16 + 1,
3434
precipitation_target_bins = dict(
3535
mrms_rate = 512,
3636
mrms_accumulation = 512,
@@ -43,16 +43,16 @@ metnet3 = MetNet3(
4343
omo_wind_component_y = 256,
4444
omo_wind_direction = 180
4545
),
46-
hrrr_loss_weight = 10,
47-
hrrr_target_channels = 617
46+
hrrr_loss_weight = 10
4847
)
4948

5049
# inputs
5150

5251
lead_times = torch.randint(0, 722, (2,))
53-
sparse_input_2496 = torch.randn((2, 8, 624, 624))
54-
dense_input_2496 = torch.randn((2, 8, 624, 624))
55-
dense_input_4996 = torch.randn((2, 8, 624, 624))
52+
hrrr_input_2496 = torch.randn((2, 617, 624, 624))
53+
hrrr_stale_state = torch.randn((2, 1, 624, 624))
54+
input_2496 = torch.randn((2, 39, 624, 624))
55+
input_4996 = torch.randn((2, 17, 624, 624))
5656

5757
# targets
5858

@@ -74,9 +74,10 @@ hrrr_target = torch.randn(2, 617, 128, 128)
7474

7575
total_loss, loss_breakdown = metnet3(
7676
lead_times = lead_times,
77-
sparse_input_2496 = sparse_input_2496,
78-
dense_input_2496 = dense_input_2496,
79-
dense_input_4996 = dense_input_4996,
77+
hrrr_input_2496 = hrrr_input_2496,
78+
hrrr_stale_state = hrrr_stale_state,
79+
input_2496 = input_2496,
80+
input_4996 = input_4996,
8081
precipitation_targets = precipitation_targets,
8182
surface_targets = surface_targets,
8283
hrrr_target = hrrr_target
@@ -88,13 +89,15 @@ total_loss.backward()
8889

8990
metnet3.eval()
9091

91-
surface_targets, hrrr_target, precipitation_targets = metnet3(
92+
surface_preds, hrrr_pred, precipitation_preds = metnet3(
9293
lead_times = lead_times,
93-
sparse_input_2496 = sparse_input_2496,
94-
dense_input_2496 = dense_input_2496,
95-
dense_input_4996 = dense_input_4996
94+
hrrr_input_2496 = hrrr_input_2496,
95+
hrrr_stale_state = hrrr_stale_state,
96+
input_2496 = input_2496,
97+
input_4996 = input_4996,
9698
)
9799

100+
98101
# Dict[str, Tensor], Tensor, Dict[str, Tensor]
99102
```
100103

@@ -103,8 +106,8 @@ surface_targets, hrrr_target, precipitation_targets = metnet3(
103106
- [x] figure out all the cross entropy and MSE losses
104107
- [x] auto-handle normalization across all the channels of the HRRR by tracking a running mean and variance of targets during training (using sync batchnorm as hack)
105108
- [x] allow researcher to pass in their own normalization variables for HRRR
109+
- [x] build all the inputs to spec, also make sure hrrr input is normalized, offer option to unnormalize hrrr predictions
106110

107-
- [ ] build all the inputs to spec, also make sure hrrr input is normalized, offer option to unnormalize hrrr predictions
108111
- [ ] make sure model can be easily saved and loaded, with different ways of handling hrrr norm
109112
- [ ] figure out the topological embedding, consult a neural weather researcher
110113

metnet3_pytorch/metnet3_pytorch.py

Lines changed: 54 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -594,9 +594,8 @@ def __init__(
594594
vit_window_size = 8,
595595
vit_mbconv_expansion_rate = 4,
596596
vit_mbconv_shrinkage_rate = 0.25,
597-
sparse_input_2496_channels = 8,
598-
dense_input_2496_channels = 8,
599-
dense_input_4996_channels = 8,
597+
input_2496_channels = 2 + 14 + 1 + 2 + 20,
598+
input_4996_channels = 16 + 1,
600599
surface_and_hrrr_target_spatial_size = 128,
601600
precipitation_target_bins: Dict[str, int] = dict(
602601
mrms_rate = 512,
@@ -610,26 +609,26 @@ def __init__(
610609
omo_wind_component_y = 256,
611610
omo_wind_direction = 180
612611
),
613-
hrrr_target_channels = 617,
612+
hrrr_channels = 617,
614613
hrrr_norm_statistics: Optional[Tensor] = None,
615614
hrrr_loss_weight = 10,
616615
crop_size_post_16km = 48,
617616
resnet_block_depth = 2,
618617
):
619618
super().__init__()
620-
self.sparse_input_2496_shape = (sparse_input_2496_channels, input_spatial_size, input_spatial_size)
621-
self.dense_input_2496_shape = (dense_input_2496_channels, input_spatial_size, input_spatial_size)
622-
self.dense_input_4996_shape = (dense_input_4996_channels, input_spatial_size, input_spatial_size)
619+
self.hrrr_input_2496_shape = (hrrr_channels, input_spatial_size, input_spatial_size)
620+
self.input_2496_shape = (input_2496_channels, input_spatial_size, input_spatial_size)
621+
self.input_4996_shape = (input_4996_channels, input_spatial_size, input_spatial_size)
623622

624623
self.surface_and_hrrr_target_spatial_size = surface_and_hrrr_target_spatial_size
625624

626625
self.surface_target_shape = ((self.surface_and_hrrr_target_spatial_size,) * 2)
627-
self.hrrr_target_shape = (hrrr_target_channels, *self.surface_target_shape)
626+
self.hrrr_target_shape = (hrrr_channels, *self.surface_target_shape)
628627
self.precipitation_target_shape = (surface_and_hrrr_target_spatial_size * 4,) * 2
629628

630629
self.lead_time_embedding = nn.Embedding(num_lead_times, lead_time_embed_dim)
631630

632-
dim_in_4km = sparse_input_2496_channels + dense_input_2496_channels
631+
dim_in_4km = hrrr_channels + input_2496_channels + 1
633632

634633
self.to_skip_connect_4km = CenterCrop(crop_size_post_16km * 4)
635634

@@ -645,7 +644,7 @@ def __init__(
645644
CenterPad(input_spatial_size)
646645
)
647646

648-
dim_in_8km = dense_input_4996_channels + dim
647+
dim_in_8km = input_4996_channels + dim
649648

650649
self.resnet_blocks_down_8km = ResnetBlocks(
651650
dim = dim,
@@ -726,53 +725,79 @@ def __init__(
726725

727726
self.to_hrrr_pred = Sequential(
728727
ChanLayerNorm(dim),
729-
nn.Conv2d(dim, hrrr_target_channels, 1)
728+
nn.Conv2d(dim, hrrr_channels, 1)
730729
)
731730

732731
# they scale hrrr loss by 10. but also divided by number of channels
733732

734-
self.hrrr_loss_weight = hrrr_loss_weight / hrrr_target_channels
733+
self.hrrr_loss_weight = hrrr_loss_weight / hrrr_channels
735734

736735
self.has_hrrr_norm_statistics = exists(hrrr_norm_statistics)
737736

738737
if self.has_hrrr_norm_statistics:
739-
assert hrrr_norm_statistics.shape == (2, hrrr_target_channels), f'normalization statistics must be of shape (2, {normed_hrrr_target}), containing mean and variance of each target calculated from the dataset'
738+
assert hrrr_norm_statistics.shape == (2, hrrr_channels), f'normalization statistics must be of shape (2, {normed_hrrr_target}), containing mean and variance of each target calculated from the dataset'
740739
self.register_buffer('hrrr_norm_statistics', hrrr_norm_statistics)
741740
else:
742-
self.batchnorm_hrrr = MaybeSyncBatchnorm2d()(hrrr_target_channels, affine = False)
741+
self.batchnorm_hrrr = MaybeSyncBatchnorm2d()(hrrr_channels, affine = False)
743742

744743
self.mse_loss_scaler = LossScaler()
745744

746745
@beartype
747746
def forward(
748747
self,
748+
*,
749749
lead_times,
750-
sparse_input_2496,
751-
dense_input_2496,
752-
dense_input_4996,
750+
hrrr_input_2496,
751+
hrrr_stale_state,
752+
input_2496,
753+
input_4996,
753754
surface_targets: Optional[Dict[str, Tensor]] = None,
754755
precipitation_targets: Optional[Dict[str, Tensor]] = None,
755756
hrrr_target: Optional[Tensor] = None,
756757
):
757758
batch = lead_times.shape[0]
758759

759-
assert batch == sparse_input_2496.shape[0] == dense_input_2496.shape[0] == dense_input_4996.shape[0], 'batch size across all inputs must be the same'
760+
assert batch == hrrr_input_2496.shape[0] == input_2496.shape[0] == input_4996.shape[0], 'batch size across all inputs must be the same'
760761

761-
assert sparse_input_2496.shape[1:] == self.sparse_input_2496_shape
762-
assert dense_input_2496.shape[1:] == self.dense_input_2496_shape
763-
assert dense_input_4996.shape[1:] == self.dense_input_4996_shape
762+
assert hrrr_input_2496.shape[1:] == self.hrrr_input_2496_shape
763+
assert input_2496.shape[1:] == self.input_2496_shape
764+
assert input_4996.shape[1:] == self.input_4996_shape
765+
766+
# normalize HRRR input and target as needed
767+
768+
if self.has_hrrr_norm_statistics:
769+
mean, variance = self.hrrr_norm_statistics
770+
mean = rearrange(mean, 'c -> c 1 1')
771+
variance = rearrange(variance, 'c -> c 1 1')
772+
inv_std = variance.clamp(min = 1e-5).rsqrt()
773+
774+
normed_hrrr_input = (hrrr_input_2496 - mean) * inv_std
775+
776+
if exists(hrrr_target):
777+
normed_hrrr_target = (hrrr_target - mean) * inv_std
778+
779+
else:
780+
# use a batchnorm to normalize each channel to mean zero and unit variance
781+
782+
with freeze_batchnorm(self.batchnorm_hrrr) as frozen_batchnorm:
783+
normed_hrrr_input = frozen_batchnorm(hrrr_input_2496)
784+
785+
if exists(hrrr_target):
786+
normed_hrrr_target = frozen_batchnorm(hrrr_target)
787+
788+
# main network
764789

765790
cond = self.lead_time_embedding(lead_times)
766791

767-
x = torch.cat((sparse_input_2496, dense_input_2496), dim = 1)
792+
x = torch.cat((normed_hrrr_input, hrrr_stale_state, input_2496), dim = 1)
768793

769794
skip_connect_4km = self.to_skip_connect_4km(x)
770795

771796
x = self.resnet_blocks_down_4km(x, cond = cond)
772797

773798
x = self.downsample_and_pad_to_8km(x)
774799

775-
x = torch.cat((dense_input_4996, x), dim = 1)
800+
x = torch.cat((input_4996, x), dim = 1)
776801

777802
skip_connect_8km = self.to_skip_connect_8km(x)
778803

@@ -866,30 +891,16 @@ def forward(
866891
ce_losses = ce_losses + precipition_loss
867892

868893
# calculate HRRR mse loss
894+
# proposed loss gradient rescaler from section 4.3.2
869895

870-
if self.has_hrrr_norm_statistics:
871-
mean, variance = self.hrrr_norm_statistics
872-
mean = rearrange(mean, 'c -> c 1 1')
873-
variance = rearrange(variance, 'c -> c 1 1')
874-
inv_std = variance.clamp(min = 1e-5).rsqrt()
875-
876-
normed_hrrr_target = (hrrr_target - mean) * inv_std
877-
normed_hrrr_pred = (hrrr_pred - mean) * inv_std
878-
else:
879-
# use a batchnorm to normalize each channel to mean zero and unit variance
880-
881-
if self.training:
882-
_ = self.batchnorm_hrrr(hrrr_target)
883-
884-
with freeze_batchnorm(self.batchnorm_hrrr) as frozen_batchnorm:
885-
normed_hrrr_pred = frozen_batchnorm(hrrr_pred)
886-
normed_hrrr_target = frozen_batchnorm(hrrr_target)
896+
hrrr_pred = self.mse_loss_scaler(hrrr_pred)
887897

888-
# proposed loss gradient rescaler from section 4.3.2
898+
hrrr_loss = F.mse_loss(hrrr_pred, normed_hrrr_target)
889899

890-
normed_hrrr_pred = self.mse_loss_scaler(normed_hrrr_pred)
900+
# update hrrr normalization statistics, if using batchnorm way
891901

892-
hrrr_loss = F.mse_loss(normed_hrrr_pred, normed_hrrr_target)
902+
if not self.has_hrrr_norm_statistics and self.training:
903+
_ = self.batchnorm_hrrr(hrrr_target)
893904

894905
# total loss
895906

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'metnet3-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.8',
6+
version = '0.0.9',
77
license='MIT',
88
description = 'MetNet 3 - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)