Skip to content

Commit f55a164

Browse files
committed
use earthkit-data
1 parent 9b28685 commit f55a164

File tree

10 files changed

+164
-136
lines changed

10 files changed

+164
-136
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,4 @@ cython_debug/
167167
?.*
168168
*.png
169169
*.pny
170+
_version.py

.pre-commit-config.yaml

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
repos:
2+
3+
# Empty notebooks
4+
- repo: local
5+
hooks:
6+
- id: clear-notebooks-output
7+
name: clear-notebooks-output
8+
files: tools/.*\.ipynb$
9+
stages: [commit]
10+
language: python
11+
entry: jupyter nbconvert --ClearOutputPreprocessor.enabled=True --inplace
12+
additional_dependencies: [jupyter]
13+
14+
15+
- repo: https://github.com/pre-commit/pre-commit-hooks
16+
rev: v4.4.0
17+
hooks:
18+
- id: check-yaml # Check YAML files for syntax errors only
19+
args: [--unsafe, --allow-multiple-documents]
20+
- id: debug-statements # Check for debugger imports and py37+ breakpoint()
21+
- id: end-of-file-fixer # Ensure files end in a newline
22+
- id: trailing-whitespace # Trailing whitespace checker
23+
- id: no-commit-to-branch # Prevent committing to main / master
24+
- id: check-added-large-files # Check for large files added to git
25+
- id: check-merge-conflict # Check for files that contain merge conflict
26+
27+
- repo: https://github.com/psf/black-pre-commit-mirror
28+
rev: 24.1.1
29+
hooks:
30+
- id: black
31+
args: [--line-length=120]
32+
33+
- repo: https://github.com/pycqa/isort
34+
rev: 5.13.2
35+
hooks:
36+
- id: isort
37+
args:
38+
- -l 120
39+
- --force-single-line-imports
40+
- --profile black
41+
42+
43+
- repo: https://github.com/astral-sh/ruff-pre-commit
44+
rev: v0.3.0
45+
hooks:
46+
- id: ruff
47+
exclude: '(dev/.*|.*_)\.py$'
48+
args:
49+
- --line-length=120
50+
- --fix
51+
- --exit-non-zero-on-fix
52+
- --preview
53+
54+
- repo: https://github.com/sphinx-contrib/sphinx-lint
55+
rev: v0.9.1
56+
hooks:
57+
- id: sphinx-lint
58+
59+
# For now, we use it. But it does not support a lot of sphinx features
60+
- repo: https://github.com/dzhu/rstfmt
61+
rev: v0.0.14
62+
hooks:
63+
- id: rstfmt
64+
65+
- repo: https://github.com/b8raoult/pre-commit-docconvert
66+
rev: "0.1.4"
67+
hooks:
68+
- id: docconvert
69+
args: ["numpy"]

pyproject.toml

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#!/usr/bin/env python
2+
# (C) Copyright 2024 ECMWF.
3+
#
4+
# This software is licensed under the terms of the Apache Licence Version 2.0
5+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
6+
# In applying this licence, ECMWF does not waive the privileges and immunities
7+
# granted to it by virtue of its status as an intergovernmental organisation
8+
# nor does it submit to any jurisdiction.
9+
10+
# https://packaging.python.org/en/latest/guides/writing-pyproject-toml/
11+
12+
[build-system]
13+
requires = [
14+
"setuptools>=60",
15+
"setuptools-scm>=8",
16+
]
17+
18+
[project]
19+
name = "ai-models-graphcast"
20+
21+
description = "An ai-models plugin to run DeepMind's GraphCast model"
22+
keywords = [
23+
"ai",
24+
"tools",
25+
]
26+
27+
license = { file = "LICENSE" }
28+
authors = [
29+
{ name = "European Centre for Medium-Range Weather Forecasts (ECMWF)", email = "[email protected]" },
30+
]
31+
32+
requires-python = ">=3.10"
33+
34+
classifiers = [
35+
"Development Status :: 4 - Beta",
36+
"Intended Audience :: Developers",
37+
"License :: OSI Approved :: Apache Software License",
38+
"Operating System :: OS Independent",
39+
"Programming Language :: Python :: 3 :: Only",
40+
"Programming Language :: Python :: 3.9",
41+
"Programming Language :: Python :: 3.10",
42+
"Programming Language :: Python :: 3.11",
43+
"Programming Language :: Python :: 3.12",
44+
"Programming Language :: Python :: Implementation :: CPython",
45+
"Programming Language :: Python :: Implementation :: PyPy",
46+
]
47+
48+
dynamic = [
49+
"version",
50+
]
51+
52+
# JAX requirements are in requirements.txt
53+
54+
dependencies = [
55+
"ai-models>=0.4.0",
56+
"dm-tree",
57+
"dm-haiku==0.0.10",
58+
]
59+
60+
optional-dependencies.dev = [
61+
"pre-commit",
62+
]
63+
urls.Repository = "https://github.com/ecmwf-lab/ai-models-graphcast"
64+
entry-points."ai_models.model".graphcast = "ai_models_graphcast.model:model"
65+
66+
[tool.setuptools_scm]
67+
version_file = "src/ai_models_graphcast/_version.py"

setup.py

Lines changed: 0 additions & 68 deletions
This file was deleted.
File renamed without changes.
File renamed without changes.

ai_models_graphcast/input.py renamed to src/ai_models_graphcast/input.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import logging
1111
from collections import defaultdict
1212

13-
import climetlab as cml
13+
import earthkit.data as ekd
1414
import numpy as np
1515
import xarray as xr
1616

@@ -46,8 +46,8 @@ def forcing_variables_numpy(sample, forcing_variables, dates):
4646
Returns:
4747
torch.Tensor: Tensor with constants
4848
"""
49-
ds = cml.load_source(
50-
"constants",
49+
ds = ekd.from_source(
50+
"forcings",
5151
sample,
5252
date=dates,
5353
param=forcing_variables,
@@ -74,16 +74,13 @@ def create_training_xarray(
7474
):
7575
time_deltas = [
7676
datetime.timedelta(hours=h)
77-
for h in lagged
78-
+ [hour for hour in range(hour_steps, lead_time + hour_steps, hour_steps)]
77+
for h in lagged + [hour for hour in range(hour_steps, lead_time + hour_steps, hour_steps)]
7978
]
8079

81-
all_datetimes = [start_date() + time_delta for time_delta in time_deltas]
80+
all_datetimes = [start_date + time_delta for time_delta in time_deltas]
8281

8382
with timer("Creating forcing variables"):
84-
forcing_numpy = forcing_variables_numpy(
85-
fields_sfc, forcing_variables, all_datetimes
86-
)
83+
forcing_numpy = forcing_variables_numpy(fields_sfc, forcing_variables, all_datetimes)
8784

8885
with timer("Converting GRIB to xarray"):
8986
# Create Input dataset
@@ -118,9 +115,7 @@ def create_training_xarray(
118115
data_vars[CF_NAME_SFC[param]] = (["lat", "lon"], fields[0].to_numpy())
119116
continue
120117

121-
data = np.stack(
122-
[field.to_numpy(dtype=np.float32) for field in fields]
123-
).reshape(
118+
data = np.stack([field.to_numpy(dtype=np.float32) for field in fields]).reshape(
124119
1,
125120
len(given_datetimes),
126121
len(lat),
@@ -141,9 +136,7 @@ def create_training_xarray(
141136
data_vars[CF_NAME_SFC[param]] = (["batch", "time", "lat", "lon"], data)
142137

143138
for param, fields in pl.items():
144-
data = np.stack(
145-
[field.to_numpy(dtype=np.float32) for field in fields]
146-
).reshape(
139+
data = np.stack([field.to_numpy(dtype=np.float32) for field in fields]).reshape(
147140
1,
148141
len(given_datetimes),
149142
len(levels),
@@ -188,9 +181,7 @@ def create_training_xarray(
188181

189182
with timer("Reindexing"):
190183
# And we want the grid south to north
191-
training_xarray = training_xarray.reindex(
192-
lat=sorted(training_xarray.lat.values)
193-
)
184+
training_xarray = training_xarray.reindex(lat=sorted(training_xarray.lat.values))
194185

195186
if constants:
196187
# Add geopotential_at_surface and land_sea_mask back in

ai_models_graphcast/model.py renamed to src/ai_models_graphcast/model.py

Lines changed: 13 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,10 @@
77

88

99
import dataclasses
10-
import datetime
1110
import functools
1211
import gc
1312
import logging
1413
import os
15-
from functools import cached_property
1614

1715
import xarray
1816
from ai_models.model import Model
@@ -26,14 +24,12 @@
2624
try:
2725
import haiku as hk
2826
import jax
29-
from graphcast import (
30-
autoregressive,
31-
casting,
32-
checkpoint,
33-
data_utils,
34-
graphcast,
35-
normalization,
36-
)
27+
from graphcast import autoregressive
28+
from graphcast import casting
29+
from graphcast import checkpoint
30+
from graphcast import data_utils
31+
from graphcast import graphcast
32+
from graphcast import normalization
3733
except ModuleNotFoundError as e:
3834
msg = "You need to install Graphcast from git to use this model. See README.md for details."
3935
LOG.error(msg)
@@ -88,9 +84,7 @@ def __init__(self, **kwargs):
8884
self.lagged = [-6, 0]
8985
self.params = None
9086
self.ordering = self.param_sfc + [
91-
f"{param}{level}"
92-
for param in self.param_level_pl[0]
93-
for level in self.param_level_pl[1]
87+
f"{param}{level}" for param in self.param_level_pl[0] for level in self.param_level_pl[1]
9488
]
9589

9690
# Jax doesn't seem to like passing configs as args through the jit. Passing it
@@ -119,17 +113,11 @@ def load_model(self):
119113
def get_path(filename):
120114
return os.path.join(self.assets, filename)
121115

122-
diffs_stddev_by_level = xarray.load_dataset(
123-
get_path(self.download_files[1])
124-
).compute()
116+
diffs_stddev_by_level = xarray.load_dataset(get_path(self.download_files[1])).compute()
125117

126-
mean_by_level = xarray.load_dataset(
127-
get_path(self.download_files[2])
128-
).compute()
118+
mean_by_level = xarray.load_dataset(get_path(self.download_files[2])).compute()
129119

130-
stddev_by_level = xarray.load_dataset(
131-
get_path(self.download_files[3])
132-
).compute()
120+
stddev_by_level = xarray.load_dataset(get_path(self.download_files[3])).compute()
133121

134122
def construct_wrapped_graphcast(model_config, task_config):
135123
"""Constructs and wraps the GraphCast Predictor."""
@@ -183,13 +171,7 @@ def run_forward(
183171
LOG.info("Model license: %s", self.ckpt.license)
184172

185173
jax.jit(self._with_configs(run_forward.init))
186-
self.model = self._drop_state(
187-
self._with_params(jax.jit(self._with_configs(run_forward.apply)))
188-
)
189-
190-
@cached_property
191-
def start_date(self) -> "datetime":
192-
return self.all_fields.order_by(valid_datetime="descending")[0].datetime
174+
self.model = self._drop_state(self._with_params(jax.jit(self._with_configs(run_forward.apply))))
193175

194176
def run(self):
195177
# We ignore 'tp' so that we make sure that step 0 is a field of zero values
@@ -205,7 +187,7 @@ def run(self):
205187
fields_sfc=self.fields_sfc,
206188
fields_pl=self.fields_pl,
207189
lagged=self.lagged,
208-
start_date=self.start_date,
190+
start_date=self.start_datetime,
209191
hour_steps=self.hour_steps,
210192
lead_time=self.lead_time,
211193
forcing_variables=self.forcing_variables,
@@ -226,8 +208,7 @@ def run(self):
226208
) = data_utils.extract_inputs_targets_forcings(
227209
training_xarray,
228210
target_lead_times=[
229-
f"{int(delta.days * 24 + delta.seconds/3600):d}h"
230-
for delta in time_deltas[len(self.lagged) :]
211+
f"{int(delta.days * 24 + delta.seconds/3600):d}h" for delta in time_deltas[len(self.lagged) :]
231212
],
232213
**dataclasses.asdict(self.task_config),
233214
)

0 commit comments

Comments
 (0)