11
11
import warnings
12
12
from typing import Optional
13
13
14
- from anemoi .utils .data_structures import NumpyNestedAnemoiTensor
14
+ from anemoi .utils .data_structures import NestedTrainingSample , NumpyNestedAnemoiTensor
15
15
16
16
import numpy as np
17
17
import torch
@@ -136,7 +136,7 @@ def _validate_normalization_inputs(self, name_to_index_training_input: dict, min
136
136
], f"{ method } is not a valid normalisation method"
137
137
138
138
def transform (
139
- self , x : torch .Tensor , in_place : bool , data_index : torch .Tensor ,
139
+ self , x : torch .Tensor , in_place : bool , data_index : torch .Tensor = None ,
140
140
) -> torch .Tensor :
141
141
"""Normalizes an input tensor x of shape [..., nvars].
142
142
@@ -160,10 +160,19 @@ def transform(
160
160
_description_
161
161
"""
162
162
if not in_place :
163
- x = x .clone () # TODO: fix this; implement a custom clone() op?
163
+ x = x .clone ()
164
+
165
+ assert isinstance (x , NestedTrainingSample ), type (x )
166
+ assert data_index is None
167
+
168
+ # should be a method
169
+ for s in x :
170
+ for k ,v in s .items ():
171
+ norm_mul = getattr (self , "_norm_mul" + f"__{ k } " )
172
+ norm_add = getattr (self , "_norm_add" + f"__{ k } " )
173
+ v [..., :] = v [..., :] * norm_mul + norm_add
174
+ print ('Normalisation done. OK.' )
164
175
165
- assert data_index is not None # [Mihai] we require a data_index
166
- x [..., :] = x [..., :] * self ._norm_mul [data_index ] + self ._norm_add [data_index ]
167
176
return x
168
177
169
178
def inverse_transform (
0 commit comments