Skip to content

Training script #51

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
d3df1e9
[ADD] sgm training code, loading code for DL2DV and initial config
Apr 4, 2025
f54d092
[ADD] SevaAutoencoderKL (wrapper to handle frames with image autoencoder)
Apr 7, 2025
907ac4a
[ADD] SevaWrapper
Apr 7, 2025
17d4df6
[ADD] correct naming and remove print
nviolante25 Apr 7, 2025
dd9fe71
[FIX] correct output in dataset
Apr 8, 2025
4245b41
Merge branch 'training' of github.com:nviolante25/stable-virtual-came…
Apr 8, 2025
efc7732
[ADD] conditioning in plucker and binary mask, mask out target latents
nviolante25 Apr 8, 2025
6273420
[FIX] remove input frames from return dict
Apr 8, 2025
fe6bcf2
[FIX] correct shapes in wrapper, and dataset (needs to be checked) pa…
Apr 9, 2025
31f8c33
[ADD] loss weighting
nviolante25 Apr 10, 2025
d3c6532
[ADD] average CLIP embedding of input frames
nviolante25 Apr 10, 2025
4b7675f
[ADD] better names in loss weight
nviolante25 Apr 10, 2025
5f43f47
[FIX] don't pop the mask
nviolante25 Apr 14, 2025
9325dea
[FIX] repeat cross attention
Apr 14, 2025
f79e2e7
[ADD] precompute latents from SD 2.1
nviolante25 Apr 14, 2025
3f42ad7
[ADD] replace condition and correct camera normalization
nviolante25 Apr 14, 2025
36eb7c1
Merge remote-tracking branch 'refs/remotes/origin/training' into trai…
nviolante25 Apr 14, 2025
618dd5e
[FIX] replace in conditioner
nviolante25 Apr 14, 2025
51920b8
[FIX] correct dimension to retrieve conditionings
nviolante25 Apr 14, 2025
57b12f8
[ADD] update config
nviolante25 Apr 14, 2025
9eff550
[ADD] load SEVA checkpoint
nviolante25 Apr 15, 2025
ebff52a
[FIX] upper idx when sampling adjacent frames
Apr 15, 2025
1b5dcaf
[FIX] verify that colmap poses correspond to the images, there couldb…
Apr 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
---
# Training config for a KL-regularized image autoencoder (attention-free
# encoder/decoder) using sgm's AutoencodingEngine.
# NOTE(review): loaded via OmegaConf — the ${...} interpolation below is
# resolved at config-load time.
model:
  base_learning_rate: 4.5e-6
  target: sgm.models.autoencoder.AutoencodingEngine
  params:
    input_key: jpg
    monitor: val/rec_loss

    loss_config:
      target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
      params:
        perceptual_weight: 0.25
        disc_start: 20001  # step at which the discriminator loss is enabled
        disc_weight: 0.5
        learn_logvar: true

        regularization_weights:
          kl_loss: 1.0

    regularizer_config:
      target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer

    encoder_config:
      target: sgm.modules.diffusionmodules.model.Encoder
      params:
        attn_type: none
        double_z: true  # encoder emits 2 * z_channels (mean and logvar)
        z_channels: 4
        resolution: 256
        in_channels: 3
        out_ch: 3
        ch: 128
        ch_mult: [1, 2, 4]
        num_res_blocks: 4
        attn_resolutions: []
        dropout: 0.0

    decoder_config:
      target: sgm.modules.diffusionmodules.model.Decoder
      # Decoder mirrors the encoder hyperparameters via interpolation.
      params: ${model.params.encoder_config.params}

data:
  target: sgm.data.dataset.StableDataModuleFromConfig
  params:
    train:
      datapipeline:
        urls:
          # USER: replace with the path/URL(s) of your webdataset shards.
          - DATA-PATH
        pipeline_config:
          shardshuffle: 10000
          sample_shuffle: 10000

        decoders:
          - pil

        postprocessors:
          - target: sdata.mappers.TorchVisionImageTransforms
            params:
              key: jpg
              transforms:
                - target: torchvision.transforms.Resize
                  params:
                    size: 256
                    interpolation: 3
                - target: torchvision.transforms.ToTensor
          - target: sdata.mappers.Rescaler
          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
            params:
              h_key: height
              w_key: width

      loader:
        batch_size: 8
        num_workers: 4

lightning:
  strategy:
    target: pytorch_lightning.strategies.DDPStrategy
    params:
      find_unused_parameters: true

  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 50000

    image_logger:
      target: main.ImageLogger
      params:
        enable_autocast: false
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: true

  trainer:
    # Quoted so every YAML parser keeps the Lightning device-list string
    # (trailing comma) rather than an ambiguous plain scalar.
    devices: "0,"
    limit_val_batches: 50
    benchmark: true
    accumulate_grad_batches: 1
    val_check_interval: 10000
105 changes: 105 additions & 0 deletions configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
---
# Training config for an 8-channel KL-regularized autoencoder (kl-f8,
# xformers attention) using sgm's AutoencodingEngine.
model:
  base_learning_rate: 4.5e-6
  target: sgm.models.autoencoder.AutoencodingEngine
  params:
    input_key: jpg
    monitor: val/loss/rec
    disc_start_iter: 0

    encoder_config:
      target: sgm.modules.diffusionmodules.model.Encoder
      params:
        attn_type: vanilla-xformers
        double_z: true  # encoder emits 2 * z_channels (mean and logvar)
        z_channels: 8
        resolution: 256
        in_channels: 3
        out_ch: 3
        ch: 128
        ch_mult: [1, 2, 4, 4]
        num_res_blocks: 2
        attn_resolutions: []
        dropout: 0.0

    decoder_config:
      target: sgm.modules.diffusionmodules.model.Decoder
      # Decoder mirrors the encoder hyperparameters via interpolation.
      params: ${model.params.encoder_config.params}

    regularizer_config:
      target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer

    loss_config:
      target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
      params:
        perceptual_weight: 0.25
        disc_start: 20001  # step at which the discriminator loss is enabled
        disc_weight: 0.5
        learn_logvar: true

        regularization_weights:
          kl_loss: 1.0

data:
  target: sgm.data.dataset.StableDataModuleFromConfig
  params:
    train:
      datapipeline:
        urls:
          # USER: replace with the path/URL(s) of your webdataset shards.
          - DATA-PATH
        pipeline_config:
          shardshuffle: 10000
          sample_shuffle: 10000

        decoders:
          - pil

        postprocessors:
          - target: sdata.mappers.TorchVisionImageTransforms
            params:
              key: jpg
              transforms:
                - target: torchvision.transforms.Resize
                  params:
                    size: 256
                    interpolation: 3
                - target: torchvision.transforms.ToTensor
          - target: sdata.mappers.Rescaler
          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
            params:
              h_key: height
              w_key: width

      loader:
        batch_size: 8
        num_workers: 4

lightning:
  strategy:
    target: pytorch_lightning.strategies.DDPStrategy
    params:
      find_unused_parameters: true

  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 50000

    image_logger:
      target: main.ImageLogger
      params:
        enable_autocast: false
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: true

  trainer:
    # Quoted so every YAML parser keeps the Lightning device-list string
    # (trailing comma) rather than an ambiguous plain scalar.
    devices: "0,"
    limit_val_batches: 50
    benchmark: true
    accumulate_grad_batches: 1
    val_check_interval: 10000
185 changes: 185 additions & 0 deletions configs/example_training/imagenet-f8_cond.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
---
# Class-conditional latent diffusion training config (sgm DiffusionEngine,
# f8 latent space) with a frozen KL autoencoder as first stage.
model:
  base_learning_rate: 1.0e-4
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: true
    log_keys:
      - cls

    scheduler_config:
      target: sgm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [10000]
        cycle_lengths: [10000000000000]
        # Floats written with explicit digits after the dot ("1.0e-6", not
        # "1.e-6") so they parse as floats under YAML 1.2 core parsers too.
        f_start: [1.0e-6]
        f_max: [1.0]
        f_min: [1.0]

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: true  # gradient checkpointing to save memory
        in_channels: 4
        out_channels: 4
        model_channels: 256
        attention_resolutions: [1, 2, 4]
        num_res_blocks: 2
        channel_mult: [1, 2, 4]
        num_head_channels: 64
        num_classes: sequential
        adm_in_channels: 1024
        transformer_depth: 1
        context_dim: 1024
        spatial_transformer_attn_type: softmax-xformers

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          # Class-label embedding (trainable); ucg_rate drops the condition
          # for classifier-free guidance training.
          - is_trainable: true
            input_key: cls
            ucg_rate: 0.2
            target: sgm.modules.encoders.modules.ClassEmbedder
            params:
              add_sequence_dim: true
              embed_dim: 1024
              n_classes: 1000

          - is_trainable: false
            ucg_rate: 0.2
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

          - is_trainable: false
            input_key: crop_coords_top_left
            ucg_rate: 0.2
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKL
      params:
        # USER: set to a pretrained autoencoder checkpoint path; explicit
        # null (instead of a bare "ckpt_path:") makes the default obvious.
        ckpt_path: null
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity  # first stage is frozen; no AE loss here

    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
        loss_weighting_config:
          target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
          params:
            num_idx: 1000

            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50

        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
          params:
            scale: 5.0  # classifier-free guidance scale at sampling time

data:
  target: sgm.data.dataset.StableDataModuleFromConfig
  params:
    train:
      datapipeline:
        urls:
          # USER: adapt this path to the root of your custom dataset
          - DATA_PATH
        pipeline_config:
          shardshuffle: 10000
          sample_shuffle: 10000  # USER: you might wanna adapt depending on your available RAM

        decoders:
          - pil

        postprocessors:
          - target: sdata.mappers.TorchVisionImageTransforms
            params:
              key: jpg  # USER: you might wanna adapt this for your custom dataset
              transforms:
                - target: torchvision.transforms.Resize
                  params:
                    size: 256
                    interpolation: 3
                - target: torchvision.transforms.ToTensor
          - target: sdata.mappers.Rescaler

          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
            params:
              h_key: height  # USER: you might wanna adapt this for your custom dataset
              w_key: width  # USER: you might wanna adapt this for your custom dataset

      loader:
        batch_size: 64
        num_workers: 6

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000

    image_logger:
      target: main.ImageLogger
      params:
        disabled: false
        enable_autocast: false
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: true
        log_first_step: false
        log_images_kwargs:
          use_ema_scope: false
          N: 8
          n_rows: 2

  trainer:
    # Quoted so every YAML parser keeps the Lightning device-list string
    # (trailing comma) rather than an ambiguous plain scalar.
    devices: "0,"
    benchmark: true
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_epochs: 1000
Loading