Commit e898b76

redo the hrrr norm strategy to be completely flexible

1 parent 98f99ba commit e898b76

3 files changed: +81 -13 lines changed

README.md

Lines changed: 4 additions & 2 deletions

@@ -43,7 +43,9 @@ metnet3 = MetNet3(
         omo_wind_component_y = 256,
         omo_wind_direction = 180
     ),
-    hrrr_loss_weight = 10
+    hrrr_loss_weight = 10,
+    hrrr_norm_strategy = 'sync_batchnorm', # uses a sync batchnorm to normalize the input hrrr and the target, without having to precalculate the per-channel mean and variance of the hrrr dataset
+    hrrr_norm_statistics = None            # alternatively, set `hrrr_norm_strategy = "precalculated"` and pass the per-channel mean and variance as a tensor of shape (2, 617) through this keyword argument
 )

 # inputs

@@ -107,8 +109,8 @@ surface_preds, hrrr_pred, precipitation_preds = metnet3(
 - [x] auto-handle normalization across all the channels of the HRRR by tracking a running mean and variance of targets during training (using sync batchnorm as hack)
 - [x] allow researcher to pass in their own normalization variables for HRRR
 - [x] build all the inputs to spec, also make sure hrrr input is normalized, offer option to unnormalize hrrr predictions
+- [x] make sure model can be easily saved and loaded, with different ways of handling hrrr norm

-- [ ] make sure model can be easily saved and loaded, with different ways of handling hrrr norm
 - [ ] figure out the topological embedding, consult a neural weather researcher

 ## Citations
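For context, a minimal usage sketch of the two non-default strategies introduced above, assuming the package-level `MetNet3` import from the README; every other constructor argument is omitted here and keeps its default, and the statistics tensor is a placeholder, not real dataset values:

import torch
from metnet3_pytorch import MetNet3

# 'sync_batchnorm' - running mean / variance of the 617 hrrr channels is tracked during training

metnet3 = MetNet3(hrrr_norm_strategy = 'sync_batchnorm')

# 'precalculated' - supply dataset statistics yourself; row 0 holds the per-channel mean, row 1 the variance

stats = torch.ones(2, 617)  # placeholder for statistics computed over the hrrr dataset

metnet3 = MetNet3(hrrr_norm_strategy = 'precalculated', hrrr_norm_statistics = stats)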

metnet3_pytorch/metnet3_pytorch.py

Lines changed: 76 additions & 10 deletions
@@ -1,3 +1,4 @@
+from pathlib import Path
 from contextlib import contextmanager
 from functools import partial
 from collections import namedtuple

@@ -13,7 +14,9 @@
 from einops.layers.torch import Rearrange, Reduce

 from beartype import beartype
-from beartype.typing import Tuple, Union, List, Optional, Dict
+from beartype.typing import Tuple, Union, List, Optional, Dict, Literal
+
+import pickle

 # helpers

@@ -609,13 +612,27 @@ def __init__(
             omo_wind_component_y = 256,
             omo_wind_direction = 180
         ),
+        hrrr_norm_strategy: Union[
+            Literal['none'],
+            Literal['precalculated'],
+            Literal['sync_batchnorm']
+        ] = 'none',
         hrrr_channels = 617,
         hrrr_norm_statistics: Optional[Tensor] = None,
         hrrr_loss_weight = 10,
         crop_size_post_16km = 48,
         resnet_block_depth = 2,
     ):
         super().__init__()
+
+        # for autosaving the config
+
+        _locals = locals()
+        _locals.pop('self', None)
+        _locals.pop('__class__', None)
+        _locals.pop('hrrr_norm_statistics', None)
+        self._configs = pickle.dumps(_locals)
+
         self.hrrr_input_2496_shape = (hrrr_channels, input_spatial_size, input_spatial_size)
         self.input_2496_shape = (input_2496_channels, input_spatial_size, input_spatial_size)
         self.input_4996_shape = (input_4996_channels, input_spatial_size, input_spatial_size)
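The added block above pickles the constructor arguments for later reloading: it snapshots `locals()` before any attribute is assigned, then drops what should not be serialized (`self`, `__class__`, and the `hrrr_norm_statistics` tensor). A standalone sketch of the same pattern, with illustrative names not taken from the repository:

import pickle

class Model:
    def __init__(self, channels = 617, depth = 2, statistics = None):
        # snapshot the constructor arguments before anything else runs
        _locals = locals()
        _locals.pop('self', None)
        _locals.pop('__class__', None)   # only present when super() is referenced; harmless otherwise
        _locals.pop('statistics', None)  # tensors are kept out of the pickled config
        self._configs = pickle.dumps(_locals)

model = Model(channels = 32)
print(pickle.loads(model._configs))  # {'channels': 32, 'depth': 2}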
@@ -732,15 +749,60 @@ def __init__(

         self.hrrr_loss_weight = hrrr_loss_weight / hrrr_channels

-        self.has_hrrr_norm_statistics = exists(hrrr_norm_statistics)
+        self.mse_loss_scaler = LossScaler()
+
+        # norm statistics

-        if self.has_hrrr_norm_statistics:
-            assert hrrr_norm_statistics.shape == (2, hrrr_channels), f'normalization statistics must be of shape (2, {normed_hrrr_target}), containing mean and variance of each target calculated from the dataset'
+        default_hrrr_statistics = torch.empty((2, hrrr_channels), dtype = torch.float32)
+
+        if hrrr_norm_strategy == 'none':
+            self.register_buffer('hrrr_norm_statistics', default_hrrr_statistics, persistent = False)
+
+        elif hrrr_norm_strategy == 'precalculated':
+            assert exists(hrrr_norm_statistics), 'hrrr_norm_statistics must be passed in, if normalizing input hrrr as well as target with precalculated dataset mean and variance'
+            assert hrrr_norm_statistics.shape == (2, hrrr_channels), f'normalization statistics must be of shape (2, {hrrr_channels}), containing mean and variance of each target calculated from the dataset'
             self.register_buffer('hrrr_norm_statistics', hrrr_norm_statistics)
-        else:
+
+        elif hrrr_norm_strategy == 'sync_batchnorm':
+            self.register_buffer('hrrr_norm_statistics', default_hrrr_statistics, persistent = False)
             self.batchnorm_hrrr = MaybeSyncBatchnorm2d()(hrrr_channels, affine = False)

-        self.mse_loss_scaler = LossScaler()
+        self.hrrr_norm_strategy = hrrr_norm_strategy
+
+    @classmethod
+    def init_and_load_from(cls, path, strict = True):
+        path = Path(path)
+        assert path.exists()
+        pkg = torch.load(str(path), map_location = 'cpu')
+
+        assert 'config' in pkg, 'model configs were not found in this saved checkpoint'
+
+        config = pickle.loads(pkg['config'])
+        model = cls(**config)
+        model.load(path, strict = strict)
+        return model
+
+    def save(self, path, overwrite = True):
+        path = Path(path)
+        assert overwrite or not path.exists(), f'{str(path)} already exists'
+
+        pkg = dict(
+            model_state_dict = self.state_dict(),
+            config = self._configs
+        )
+
+        torch.save(pkg, str(path))
+
+    def load(self, path, strict = True):
+        path = Path(path)
+        assert path.exists()
+
+        pkg = torch.load(str(path))
+        state_dict = pkg.get('model_state_dict')
+
+        assert exists(state_dict)
+
+        self.load_state_dict(state_dict, strict = strict)

     @beartype
     def forward(
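A sketch of the intended round trip, with an arbitrary file path: `save` bundles the state dict with the pickled config, and the `init_and_load_from` classmethod rebuilds the model from the checkpoint alone. Note the design choice that tensors such as the precalculated `hrrr_norm_statistics` travel through `model_state_dict` as registered buffers rather than through the pickled config:

from metnet3_pytorch import MetNet3

metnet3 = MetNet3(hrrr_norm_strategy = 'sync_batchnorm')

metnet3.save('./metnet3.pt')

# later, possibly in a fresh process - reconstruct from the stored config, then load the weights

metnet3 = MetNet3.init_and_load_from('./metnet3.pt')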
@@ -763,9 +825,9 @@ def forward(
         assert input_2496.shape[1:] == self.input_2496_shape
         assert input_4996.shape[1:] == self.input_4996_shape

-        # normalize HRRR input and target as needed
+        # normalize HRRR input and target, if needed

-        if self.has_hrrr_norm_statistics:
+        if self.hrrr_norm_strategy == 'precalculated':
             mean, variance = self.hrrr_norm_statistics
             mean = rearrange(mean, 'c -> c 1 1')
             variance = rearrange(variance, 'c -> c 1 1')

@@ -776,7 +838,7 @@ def forward(
             if exists(hrrr_target):
                 normed_hrrr_target = (hrrr_target - mean) * inv_std

-        else:
+        elif self.hrrr_norm_strategy == 'sync_batchnorm':
             # use a batchnorm to normalize each channel to mean zero and unit variance

             with freeze_batchnorm(self.batchnorm_hrrr) as frozen_batchnorm:

@@ -785,6 +847,10 @@ def forward(
                 if exists(hrrr_target):
                     normed_hrrr_target = frozen_batchnorm(hrrr_target)

+        elif self.hrrr_norm_strategy == 'none':
+            if exists(hrrr_target):
+                normed_hrrr_target = hrrr_target
+
         # main network

         cond = self.lead_time_embedding(lead_times)
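The 'sync_batchnorm' branch leans on standard PyTorch behavior: a batchnorm constructed with `affine = False` and put into eval mode normalizes each channel with its tracked running statistics instead of the current batch's. A minimal sketch of that trick, independent of the repository's `freeze_batchnorm` and `MaybeSyncBatchnorm2d` helpers:

import torch
from torch import nn

bn = nn.BatchNorm2d(617, affine = False)

# training-mode calls accumulate a running mean / variance per channel

for _ in range(100):
    _ = bn(torch.randn(2, 617, 32, 32) * 3 + 7)

# eval mode freezes those statistics, so inputs are normalized with them

bn.eval()
with torch.no_grad():
    normed = bn(torch.randn(2, 617, 32, 32) * 3 + 7)

print(normed.mean().item(), normed.var().item())  # approximately 0 and 1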
@@ -899,7 +965,7 @@ def forward(

         # update hrrr normalization statistics, if using batchnorm way

-        if not self.has_hrrr_norm_statistics and self.training:
+        if self.training and self.hrrr_norm_strategy == 'sync_batchnorm':
             _ = self.batchnorm_hrrr(hrrr_target)

         # total loss
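For reference, the training-mode call above nudges the batchnorm's tracked statistics by an exponential moving average (PyTorch's default `momentum` is 0.1); a small sketch verifying that update rule:

import torch
from torch import nn

bn = nn.BatchNorm2d(4, affine = False)  # momentum defaults to 0.1
x = torch.randn(8, 4, 16, 16) * 2 + 5

before = bn.running_mean.clone()
_ = bn(x)  # training-mode call, mirroring `_ = self.batchnorm_hrrr(hrrr_target)`

# the call performed: running_mean <- (1 - momentum) * running_mean + momentum * batch_mean
batch_mean = x.mean(dim = (0, 2, 3))
assert torch.allclose(bn.running_mean, (1 - bn.momentum) * before + bn.momentum * batch_mean, atol = 1e-5)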

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
     name = 'metnet3-pytorch',
     packages = find_packages(exclude=[]),
-    version = '0.0.9',
+    version = '0.0.10',
     license='MIT',
     description = 'MetNet 3 - Pytorch',
     author = 'Phil Wang',
