PR #182: refactoring branch such that ML model related
base: dtensor
Changed file (module-load script):

```diff
@@ -5,5 +5,5 @@ module load gcc/12.2.0 cuda/12.1 openmpi/4.1.5-cuda121-gdr ucx/1.14.0-gdr \
 # openpmd/0.15.2-cuda121-blosc2-py3122
 # for (re-)instaling openpmd-api
 export openPMD_USE_MPI=ON
-source /home/kelling/checkout/insitumlNp2Torch26Env/bin/activate
+source /home/pandit52/venvs/Ism/bin/activate
 export PMIX_MCA_gds=hash
```

Review comment (on the changed `source` line): No edit wars: do not commit local config changes (at least not into the PR).
New file: models/__init__.py

```python
"""
Model-related modules for openpmd-learning.

This package contains:
- Model architectures
- Model initialization and loading functions
- Configuration utilities
"""

from models.architectures import ModelFinal
from models.model_factory import load_objects, get_VAE_encoder_kwargs, get_VAE_decoder_kwargs

__all__ = [
    'ModelFinal',
    'load_objects',
    'get_VAE_encoder_kwargs',
    'get_VAE_decoder_kwargs',
]
```
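With this `__init__.py`, downstream code can import through the package root; a usage sketch (not part of the diff):

```python
# Hypothetical usage: import via the package root instead of submodules.
from models import ModelFinal, load_objects

# load_objects(rank, io_config, model_config, world_size) returns
# (optimizer, scheduler, model); the config objects come from the caller.
```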
New file: models/architectures.py

```python
import torch.nn as nn
from inSituML.loss_functions import EarthMoversLoss
from inSituML.ks_models import INNModel


class ModelFinal(nn.Module):

    def __init__(
        self,
        base_network,
        inner_model,
        loss_function_IM=None,
        weight_AE=1.0,
        weight_IM=1.0,
    ):
        super().__init__()

        self.base_network = base_network
        self.inner_model = inner_model
        self.loss_function_IM = loss_function_IM
        self.weight_AE = weight_AE
        self.weight_IM = weight_IM

    def forward(self, x, y):

        loss_AE, loss_ae_reconst, kl_loss, _, encoded = self.base_network(x)

        # Check if the inner model is an instance of INNModel
        if isinstance(self.inner_model, INNModel):
            # Use the compute_losses function of INNModel
            loss_IM, l_fit, l_latent, l_rev = self.inner_model.compute_losses(
                encoded, y
            )
            total_loss = loss_AE * self.weight_AE + loss_IM * self.weight_IM

            losses = {
                "total_loss": total_loss,
                "loss_AE": loss_AE * self.weight_AE,
                "loss_IM": loss_IM * self.weight_IM,
                "loss_ae_reconst": loss_ae_reconst,
                "kl_loss": kl_loss,
                "l_fit": l_fit,
                "l_latent": l_latent,
                "l_rev": l_rev,
            }

            return losses
        else:
            # For other types of models, such as MAF
            loss_IM = self.inner_model(inputs=encoded, context=y)
            total_loss = loss_AE * self.weight_AE + loss_IM * self.weight_IM

            losses = {
                "total_loss": total_loss,
                "loss_AE": loss_AE * self.weight_AE,
                "loss_IM": loss_IM * self.weight_IM,
                "loss_ae_reconst": loss_ae_reconst,
                "kl_loss": kl_loss,
            }

            return losses

    def reconstruct(self, x, y, num_samples=1):

        if isinstance(self.inner_model, INNModel):
            lat_z_pred = self.inner_model(x, y, rev=True)
        else:
            lat_z_pred = self.inner_model.sample_pointcloud(
                num_samples=num_samples, cond=y
            )
        y = self.base_network.decoder(lat_z_pred)

        return y, lat_z_pred
```

Review comment (on `class ModelFinal`): Having a more generic concept of the model to be trained is great. Note, though, that this class remains a somewhat specific instance, i.e. it has an encoder, a decoder, and an inner model. Please rename it to something more descriptive, to respect this.
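One possible response to this comment, keeping behaviour identical (the new name below is only an illustration, not part of the PR):

```python
# Hypothetical rename: make the encoder/decoder + inner-model structure
# visible in the class name; the body stays exactly as in ModelFinal.
class AEWithLatentInnerModel(nn.Module):
    """Joint training wrapper: a base autoencoder network plus an inner
    model (e.g. an INN or a flow) fitted on the autoencoder latents."""
    # ... same __init__ / forward / reconstruct as ModelFinal above ...


# Optional alias so existing imports keep working during the transition.
ModelFinal = AEWithLatentInnerModel
```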
New file: models/model_factory.py (shown in segments, with the review comments attached)

```python
import os, sys
import torch
import torch.optim as optim
from inSituML.utilities import MMD_multiscale, fit, load_checkpoint
from inSituML.ks_models import INNModel

from inSituML.args_transform import MAPPING_TO_LOSS
from inSituML.encoder_decoder import Encoder
from inSituML.encoder_decoder import Conv3DDecoder
```

Review comment (on the `Encoder`/`Conv3DDecoder` imports): Encoder and Decoder classes should be declared in model_config (not wholly defined there, just imported from the model library in general).
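A minimal sketch of what this asks for, assuming `model_config` is a plain Python module (the attribute names below are assumptions):

```python
# model_config.py -- hypothetical additions: declare which classes to use,
# importing them from the model library rather than defining them here.
from inSituML.encoder_decoder import Encoder, Conv3DDecoder

encoder_class = Encoder
decoder_class = Conv3DDecoder
```

The factory would then pass `model_config.encoder_class` / `model_config.decoder_class` to `VAE(...)` and drop its own `inSituML.encoder_decoder` imports.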
```python
from inSituML.loss_functions import EarthMoversLoss
from inSituML.networks import VAE
from models.architectures import ModelFinal


def get_world_size():
    """Get the world size for distributed training."""
    world_size = None
    if "WORLD_SIZE" in os.environ:
        world_size = int(os.environ["WORLD_SIZE"])
    elif "SLURM_NTASKS" in os.environ:
        print(
            "[WW] WORLD_SIZE not defined in env, "
            "falling back to SLURM_NTASKS.",
            file=sys.stderr,
        )
        world_size = int(os.environ["SLURM_NTASKS"])
    else:
        raise RuntimeError("cannot determine WORLD_SIZE")
    return world_size
```

Review comment (on `get_world_size`): This is unrelated to model definitions. The main program should still handle obtaining this value. The function can be moved into a utility module, but it should be a module with parallel-training utilities.
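A sketch of the suggested move (the module path `utils/parallel.py` is an assumption):

```python
# utils/parallel.py -- hypothetical module for parallel-training helpers.
import os
import sys


def get_world_size():
    """Get the world size for distributed training."""
    if "WORLD_SIZE" in os.environ:
        return int(os.environ["WORLD_SIZE"])
    if "SLURM_NTASKS" in os.environ:
        print(
            "[WW] WORLD_SIZE not defined in env, "
            "falling back to SLURM_NTASKS.",
            file=sys.stderr,
        )
        return int(os.environ["SLURM_NTASKS"])
    raise RuntimeError("cannot determine WORLD_SIZE")
```

The main program would call this and hand the value to `load_objects`, which already takes `world_size` as a parameter.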
```python
def get_VAE_encoder_kwargs(io_config, model_config):
    """Create encoder kwargs dictionary from configs."""
    return {
        "ae_config": "non_deterministic",
        "z_dim": model_config.latent_space_dims,
        "input_dim": io_config.ps_dims,
        "conv_layer_config": [16, 32, 64, 128, 256, 608],
        "conv_add_bn": False,
        "fc_layer_config": [544],
    }
```

Review comment (on the returned dictionary): This whole dictionary should be defined in model_config (see the combined sketch after the decoder comment below).
```python
def get_VAE_decoder_kwargs(io_config, model_config):
    """Create decoder kwargs dictionary from configs."""
    return {
        "z_dim": model_config.latent_space_dims,
        "input_dim": io_config.ps_dims,
        "initial_conv3d_size": [16, 4, 4, 4],
        "add_batch_normalisation": False,
        "fc_layer_config": [1024],
    }
```

Review comment (on the returned dictionary): This whole thing should come from model_config, too.
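A combined sketch for both comments, assuming `model_config` holds the static parts of the dictionaries and the factory injects the io-dependent fields (the key names mirror the PR; the layout is an assumption):

```python
# model_config.py -- hypothetical: static encoder/decoder kwargs live here.
latent_space_dims = 32  # placeholder; the real value is project-specific

VAE_encoder_kwargs = {
    "ae_config": "non_deterministic",
    "z_dim": latent_space_dims,
    "conv_layer_config": [16, 32, 64, 128, 256, 608],
    "conv_add_bn": False,
    "fc_layer_config": [544],
}

VAE_decoder_kwargs = {
    "z_dim": latent_space_dims,
    "initial_conv3d_size": [16, 4, 4, 4],
    "add_batch_normalisation": False,
    "fc_layer_config": [1024],
}
```

The factory then only merges in what it cannot know statically, e.g. `{**model_config.VAE_encoder_kwargs, "input_dim": io_config.ps_dims}`.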
```python
def load_objects(rank, io_config, model_config, world_size):
    """Load and initialize model objects, optimizer, and scheduler."""

    # Get configuration
    config = model_config.config

    # Get model parameters
    VAE_encoder_kwargs = get_VAE_encoder_kwargs(io_config, model_config)
    VAE_decoder_kwargs = get_VAE_decoder_kwargs(io_config, model_config)

    # Initialize loss function
    loss_fn_for_VAE = MAPPING_TO_LOSS[
        model_config.config["loss_function"]
    ](**model_config.config["loss_kwargs"])

    # Initialize VAE
    VAE_obj = VAE(
        encoder=Encoder,
        encoder_kwargs=VAE_encoder_kwargs,
        decoder=Conv3DDecoder,
        z_dim=model_config.latent_space_dims,
        decoder_kwargs=VAE_decoder_kwargs,
        loss_function=loss_fn_for_VAE,
        property_="momentum_force",
        particles_to_sample=io_config.number_of_particles,
        ae_config="non_deterministic",
        use_encoding_in_decoder=False,
        weight_kl=model_config.config["lambd_kl"],
        device=rank,
    )

    # Initialize inner model
    inner_model = INNModel(
        ndim_tot=config["ndim_tot"],
        ndim_x=config["ndim_x"],
        ndim_y=config["ndim_y"],
        ndim_z=config["ndim_z"],
        loss_fit=fit,
        loss_latent=MMD_multiscale,
        loss_backward=MMD_multiscale,
        lambd_predict=config["lambd_predict"],
        lambd_latent=config["lambd_latent"],
        lambd_rev=config["lambd_rev"],
        zeros_noise_scale=config["zeros_noise_scale"],
        y_noise_scale=config["y_noise_scale"],
        hidden_size=config["hidden_size"],
        activation=config["activation"],
        num_coupling_layers=config["num_coupling_layers"],
        device=rank,
    )
```

Review comment (on the `INNModel` arguments): These params should come from a dictionary in model_config. Most of them already do; maybe that dictionary could be renamed (more descriptive than "config", maybe "inner_model_config"). The details can then also be left out here, by just passing the dictionary.
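A sketch of the suggested simplification, assuming model_config exposes a dict named `inner_model_config` (the reviewer's proposed name, not in the PR):

```python
# Hypothetical: scalar INN hyperparameters come from one dict in
# model_config; only callables and the device stay explicit here.
inner_model = INNModel(
    loss_fit=fit,
    loss_latent=MMD_multiscale,
    loss_backward=MMD_multiscale,
    device=rank,
    **model_config.inner_model_config,
)
```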
```python
    # Initialize final model
    model = ModelFinal(
        VAE_obj,
        inner_model,
        EarthMoversLoss(),
        weight_AE=config["lambd_AE"],
        weight_IM=config["lambd_IM"],
    )

    # Load a pre-trained model
    map_location = {"cuda:%d" % 0: "cuda:%d" % rank}
    if config["load_model"] is not None:
        original_state_dict = torch.load(
            config["load_model"], map_location=map_location
        )
        # updated_state_dict = {key.replace('VAE.', 'base_network.'):
        #     value for key, value in original_state_dict.items()}
        model.load_state_dict(original_state_dict)
        print("Loaded pre-trained model successfully", flush=True)

    elif config["load_model_checkpoint"] is not None:
        model, _, _, _, _, _ = load_checkpoint(
            config["load_model_checkpoint"],
            model,
            map_location=map_location,
        )
        print("Loaded model checkpoint successfully", flush=True)
    else:
        pass  # run with random init

    lr = config["lr"]
    bs_factor = (
        io_config.trainBatchBuffer_config["training_bs"] / 2 * world_size
    )
    lr = lr * config["lr_scaling"](bs_factor)
    print(
        "Scaling learning rate from {} to {} due to bs factor {}".format(
            config["lr"], lr, bs_factor
        ),
        flush=True,
    )

    optimizer = optim.Adam(
        [
            {
                "params": model.base_network.parameters(),
                "lr": lr * config["lrAEmult"],
            },
            {"params": model.inner_model.parameters()},
        ],  # model.parameters()
        lr=lr,
        betas=config["betas"],
        eps=config["eps"],
        weight_decay=config["weight_decay"],
    )
    if ("lr_annealingRate" not in config) or config["lr_annealingRate"] is None:
        scheduler = None
    else:
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=500, gamma=config["lr_annealingRate"]
        )

    return optimizer, scheduler, model
```

Review comment (on the optimizer/scheduler setup and return): This is going beyond what should be in a model factory. Optimizer and LR scheduler should also be configurable in model_config, but with defaults available.
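A hedged sketch of configurable optimizer and scheduler with defaults (`optimizer_config` / `scheduler_config` are assumed names, not in the PR):

```python
# Hypothetical: read optimizer hyperparameters from model_config with
# fallbacks, instead of hard-coding them in the factory.
opt_cfg = getattr(model_config, "optimizer_config", {})
optimizer = optim.Adam(
    model.parameters(),
    lr=lr,
    betas=opt_cfg.get("betas", (0.9, 0.999)),
    eps=opt_cfg.get("eps", 1e-8),
    weight_decay=opt_cfg.get("weight_decay", 0.0),
)

# e.g. scheduler_config = {"step_size": 500, "gamma": 0.9}
sched_cfg = getattr(model_config, "scheduler_config", None)
scheduler = (
    torch.optim.lr_scheduler.StepLR(optimizer, **sched_cfg)
    if sched_cfg
    else None
)
```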