Skip to content
This repository was archived by the owner on Dec 20, 2024. It is now read-only.

Commit 489a241

Browse files
Feature/flexible remappers (#88)
* Add one-to-one mapper -> Monomapper * Remapper (one-to-many mapper) is now called Multimapper --------- Co-authored-by: Sara Hahner <[email protected]>
1 parent 225315e commit 489a241

File tree

7 files changed

+664
-313
lines changed

7 files changed

+664
-313
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ Keep it human-readable, your future self will thank you!
3333
- GraphTransformerMapperBlock chunking to reduce memory usage during inference [#46](https://github.com/ecmwf/anemoi-models/pull/46)
3434
- New `NamedNodesAttributes` class to handle node attributes in a more flexible way [#64](https://github.com/ecmwf/anemoi-models/pull/64)
3535
- Contributors file [#69](https://github.com/ecmwf/anemoi-models/pull/69)
36+
- Add remappers, e.g. link functions to apply during training to facilitate learning of variables with a difficult distribution [#88](https://github.com/ecmwf/anemoi-models/pull/88)
37+
- Added `supporting_arrays` argument, which contains arrays to store in checkpoints. [#97](https://github.com/ecmwf/anemoi-models/pull/97)
3638

3739
### Changed
3840
- Bugfixes for CI

src/anemoi/models/preprocessing/__init__.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,25 +57,32 @@ def __init__(
5757

5858
super().__init__()
5959

60-
self.default, self.method_config = self._process_config(config)
60+
self.default, self.remap, self.method_config = self._process_config(config)
6161
self.methods = self._invert_key_value_list(self.method_config)
6262

6363
self.data_indices = data_indices
6464

65-
@classmethod
def _process_config(cls, config):
    """Split a preprocessor config into default method, remap table, and per-key methods.

    Parameters
    ----------
    config : dict
        Preprocessor configuration. May contain the special keys ``default``
        (fallback method name) and ``remap`` (remapping table); every other
        non-``None``/non-``"none"`` entry maps a key to either a method name
        (str) or a list of method names.

    Returns
    -------
    tuple
        ``(default, remap, method_config)`` where each ``method_config`` value
        is normalised to a dict of ``{method: "<key>_<method>"}``.
    """
    # Keys that do not contain a list of variables in a preprocessing method.
    _special_keys = ["default", "remap"]
    default = config.get("default", "none")
    remap = config.get("remap", {})
    method_config = {k: v for k, v in config.items() if k not in _special_keys and v is not None and v != "none"}

    if not method_config:
        # Lazy %-style args: the message is only formatted if the record is emitted.
        LOGGER.warning(
            "%s: Using default method %s for all variables not specified in the config.",
            cls.__name__,
            default,
        )

    # Normalise str / list entries to {method: "<key>_<method>"} dicts so that
    # downstream code only ever sees the dict form.
    for key, value in method_config.items():
        if isinstance(value, str):
            method_config[key] = {value: f"{key}_{value}"}
        elif isinstance(value, list):
            method_config[key] = {method: f"{key}_{method}" for method in value}

    return default, remap, method_config
7783

78-
def _invert_key_value_list(self, method_config: dict[str, list[str]]) -> dict[str, str]:
84+
@staticmethod
85+
def _invert_key_value_list(method_config: dict[str, list[str]]) -> dict[str, str]:
7986
"""Invert a dictionary of methods with lists of variables.
8087
8188
Parameters
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# (C) Copyright 2024 Anemoi contributors.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
#
6+
# In applying this licence, ECMWF does not waive the privileges and immunities
7+
# granted to it by virtue of its status as an intergovernmental organisation
8+
# nor does it submit to any jurisdiction.
9+
10+
import torch
11+
12+
13+
def noop(x):
    """Identity mapping: return the input unchanged."""
    return x
16+
17+
18+
def cos_converter(x):
    """Map an angle given in degrees to its cosine."""
    radians = x / 180 * torch.pi
    return torch.cos(radians)
21+
22+
23+
def sin_converter(x):
    """Map an angle given in degrees to its sine."""
    radians = x / 180 * torch.pi
    return torch.sin(radians)
26+
27+
28+
def atan2_converter(x):
    """Recover an angle in degrees, in [0, 360), from stacked cos/sin components.

    Input:
        x[..., 0]: cos
        x[..., 1]: sin
    """
    degrees = torch.atan2(x[..., 1], x[..., 0]) * 180 / torch.pi
    # remainder folds atan2's (-180, 180] output into [0, 360)
    return torch.remainder(degrees, 360)
36+
37+
38+
def log1p_converter(x):
    """Map a non-negative variable to log(1 + var)."""
    return torch.log1p(x)
41+
42+
43+
def boxcox_converter(x, lambd=0.5):
    """Apply the Box-Cox transform to a positive variable.

    Parameters
    ----------
    x : torch.Tensor
        Positive-valued input.
    lambd : float, optional
        Box-Cox lambda; ``lambd == 0`` degenerates to ``log(x)``. Default 0.5.

    Returns
    -------
    torch.Tensor
        ``(x**lambd - 1) / lambd``, or ``log(x)`` when ``lambd == 0``.
    """
    # Only evaluate the branch that is needed: the original computed both
    # torch.pow(x, lambd) and torch.log(x) on every call, wasting tensor work.
    if lambd == 0:
        return torch.log(x)
    return (torch.pow(x, lambd) - 1) / lambd
51+
52+
53+
def sqrt_converter(x):
    """Map a non-negative variable to its square root."""
    return torch.sqrt(x)
56+
57+
58+
def expm1_converter(x):
    """Invert the log1p mapping: return var from log(1 + var)."""
    return torch.expm1(x)
61+
62+
63+
def square_converter(x):
    """Invert the sqrt mapping: return var from sqrt(var)."""
    squared = x**2
    return squared
66+
67+
68+
def inverse_boxcox_converter(x, lambd=0.5):
    """Invert the Box-Cox transform, recovering var from boxcox(var).

    Parameters
    ----------
    x : torch.Tensor
        Box-Cox-transformed input.
    lambd : float, optional
        Box-Cox lambda used in the forward transform; ``lambd == 0``
        degenerates to ``exp(x)``. Default 0.5.

    Returns
    -------
    torch.Tensor
        ``(x * lambd + 1) ** (1 / lambd)``, or ``exp(x)`` when ``lambd == 0``.
    """
    # Only evaluate the branch that is needed: the original computed both
    # torch.pow(...) and torch.exp(x) on every call, wasting tensor work.
    if lambd == 0:
        return torch.exp(x)
    return torch.pow(x * lambd + 1, 1 / lambd)
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# (C) Copyright 2024 Anemoi contributors.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
#
6+
# In applying this licence, ECMWF does not waive the privileges and immunities
7+
# granted to it by virtue of its status as an intergovernmental organisation
8+
# nor does it submit to any jurisdiction.
9+
10+
11+
import logging
12+
from abc import ABC
13+
from typing import Optional
14+
15+
import torch
16+
17+
from anemoi.models.data_indices.collection import IndexCollection
18+
from anemoi.models.preprocessing import BasePreprocessor
19+
from anemoi.models.preprocessing.mappings import boxcox_converter
20+
from anemoi.models.preprocessing.mappings import expm1_converter
21+
from anemoi.models.preprocessing.mappings import inverse_boxcox_converter
22+
from anemoi.models.preprocessing.mappings import log1p_converter
23+
from anemoi.models.preprocessing.mappings import noop
24+
from anemoi.models.preprocessing.mappings import sqrt_converter
25+
from anemoi.models.preprocessing.mappings import square_converter
26+
27+
LOGGER = logging.getLogger(__name__)
28+
29+
30+
class Monomapper(BasePreprocessor, ABC):
    """Remap and convert variables one-to-one (single input, single output).

    Applies an invertible pointwise mapping (e.g. log1p, sqrt, boxcox) to each
    configured variable channel before training (``transform``) and maps model
    output back to physical space afterwards (``inverse_transform``).
    """

    # method name -> [forward converter, inverse converter]
    supported_methods = {
        method: [f, inv]
        for method, f, inv in zip(
            ["log1p", "sqrt", "boxcox", "none"],
            [log1p_converter, sqrt_converter, boxcox_converter, noop],
            [expm1_converter, square_converter, inverse_boxcox_converter, noop],
        )
    }

    def __init__(
        self,
        config=None,
        data_indices: Optional[IndexCollection] = None,
        statistics: Optional[dict] = None,
    ) -> None:
        super().__init__(config, data_indices, statistics)
        self._create_remapping_indices(statistics)
        self._validate_indices()

    def _validate_indices(self):
        """Assert that all four index lists and the mapper list line up."""
        assert (
            len(self.index_training_input)
            == len(self.index_inference_input)
            == len(self.index_inference_output)
            == len(self.index_training_out)
            == len(self.remappers)
        ), (
            # Fixed message: previously printed the training-input length twice
            # and never reported the inference-output length.
            f"Error creating conversion indices {len(self.index_training_input)}, "
            f"{len(self.index_inference_input)}, {len(self.index_inference_output)}, "
            f"{len(self.index_training_out)}, {len(self.remappers)}"
        )

    def _create_remapping_indices(
        self,
        statistics=None,
    ):
        """Create the parameter indices for remapping.

        For each variable of the training input, records its forward/inverse
        converter and its position in each of the four index spaces
        (training/inference x input/output). ``None`` marks a variable that is
        absent from a space (e.g. forcings are not in the inference output).

        Raises
        ------
        KeyError
            If a variable is configured with an unsupported method.
        """
        # list for training and inference mode as position of parameters can change
        name_to_index_training_input = self.data_indices.data.input.name_to_index
        name_to_index_inference_input = self.data_indices.model.input.name_to_index
        name_to_index_training_output = self.data_indices.data.output.name_to_index
        name_to_index_inference_output = self.data_indices.model.output.name_to_index
        self.num_training_input_vars = len(name_to_index_training_input)
        self.num_inference_input_vars = len(name_to_index_inference_input)
        self.num_training_output_vars = len(name_to_index_training_output)
        self.num_inference_output_vars = len(name_to_index_inference_output)

        (
            self.remappers,
            self.backmappers,
            self.index_training_input,
            self.index_training_out,
            self.index_inference_input,
            self.index_inference_output,
        ) = (
            [],
            [],
            [],
            [],
            [],
            [],
        )

        # Create parameter indices for remapping variables
        for name in name_to_index_training_input:
            method = self.methods.get(name, self.default)
            if method in self.supported_methods:
                self.remappers.append(self.supported_methods[method][0])
                self.backmappers.append(self.supported_methods[method][1])
                self.index_training_input.append(name_to_index_training_input[name])
                if name in name_to_index_training_output:
                    self.index_training_out.append(name_to_index_training_output[name])
                else:
                    self.index_training_out.append(None)
                if name in name_to_index_inference_input:
                    self.index_inference_input.append(name_to_index_inference_input[name])
                else:
                    self.index_inference_input.append(None)
                if name in name_to_index_inference_output:
                    self.index_inference_output.append(name_to_index_inference_output[name])
                else:
                    # this is a forcing variable. It is not in the inference output.
                    self.index_inference_output.append(None)
            else:
                # Fixed: was `raise KeyError[...]` (subscription of the class),
                # which raised TypeError instead of the intended KeyError.
                raise KeyError(f"Unknown remapping method for {name}: {method}")

    def transform(self, x, in_place: bool = True) -> torch.Tensor:
        """Apply the forward converters to the configured variable channels.

        Parameters
        ----------
        x : torch.Tensor
            Input with variables along the last dimension, in either the
            training or the inference layout (detected from the size).
        in_place : bool
            If False, operate on a clone of ``x``.

        Raises
        ------
        ValueError
            If the last dimension matches neither layout.
        """
        if not in_place:
            x = x.clone()
        if x.shape[-1] == self.num_training_input_vars:
            idx = self.index_training_input
        elif x.shape[-1] == self.num_inference_input_vars:
            idx = self.index_inference_input
        else:
            raise ValueError(
                f"Input tensor ({x.shape[-1]}) does not match the training "
                f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
            )
        for i, remapper in zip(idx, self.remappers):
            if i is not None:
                x[..., i] = remapper(x[..., i])
        return x

    def inverse_transform(self, x, in_place: bool = True) -> torch.Tensor:
        """Apply the inverse converters, mapping output back to physical space.

        Parameters
        ----------
        x : torch.Tensor
            Output with variables along the last dimension, in either the
            training or the inference layout (detected from the size).
        in_place : bool
            If False, operate on a clone of ``x``.

        Raises
        ------
        ValueError
            If the last dimension matches neither layout.
        """
        if not in_place:
            x = x.clone()
        if x.shape[-1] == self.num_training_output_vars:
            idx = self.index_training_out
        elif x.shape[-1] == self.num_inference_output_vars:
            idx = self.index_inference_output
        else:
            raise ValueError(
                f"Input tensor ({x.shape[-1]}) does not match the training "
                f"({self.num_training_output_vars}) or inference shape ({self.num_inference_output_vars})",
            )
        for i, backmapper in zip(idx, self.backmappers):
            if i is not None:
                x[..., i] = backmapper(x[..., i])
        return x

0 commit comments

Comments
 (0)