diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index 225296240674a..b90c5dd4f96cd 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -130,7 +130,7 @@ def add_lightning_class_args( Args: lightning_class: A callable or any subclass of {Trainer, LightningModule, LightningDataModule, Callback}. nested_key: Name of the nested namespace to store arguments. - subclass_mode: Whether allow any subclass of the given class. + subclass_mode: Whether to allow any subclass of the given class. required: Whether the argument group is required. Returns: @@ -145,8 +145,18 @@ def add_lightning_class_args( ): if issubclass(lightning_class, Callback): self.callback_keys.append(nested_key) + + # NEW LOGIC: If subclass_mode=False and required=False, only add if config provides this key + if not subclass_mode and not required: + config_path = f"{self.subcommand}.{nested_key}" if getattr(self, "subcommand", None) else nested_key + config = getattr(self, "config", {}) + if not any(k.startswith(config_path) for k in config): + # Skip adding class arguments + return [] + if subclass_mode: return self.add_subclass_arguments(lightning_class, nested_key, fail_untyped=False, required=required) + return self.add_class_arguments( lightning_class, nested_key, @@ -154,6 +164,7 @@ def add_lightning_class_args( instantiate=not issubclass(lightning_class, Trainer), sub_configs=True, ) + raise MisconfigurationException( f"Cannot add arguments from: {lightning_class}. You should provide either a callable or a subclass of: " "Trainer, LightningModule, LightningDataModule, or Callback." diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 1b883dda0282a..9192b0e3ffc24 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -1860,6 +1860,34 @@ def test_lightning_cli_args_and_sys_argv_warning(): LightningCLI(TestModel, run=False, args=["--model.foo=789"]) +def test_add_class_args_required_false_skips_addition(tmp_path): + from lightning.pytorch import callbacks, cli + + class FooCheckpoint(callbacks.ModelCheckpoint): + def __init__(self, dirpath, *args, **kwargs): + super().__init__(dirpath, *args, **kwargs) + + class SimpleModel: + def __init__(self): + pass + + class SimpleDataModule: + def __init__(self): + pass + + class FooCLI(cli.LightningCLI): + def __init__(self): + super().__init__( + model_class=SimpleModel, datamodule_class=SimpleDataModule, run=False, save_config_callback=None + ) + + def add_arguments_to_parser(self, parser): + parser.add_lightning_class_args(FooCheckpoint, "checkpoint", required=False) + + # Expectation: No error raised even though FooCheckpoint requires `dirpath` + FooCLI() + + def test_lightning_cli_jsonnet(cleandir): class MainModule(BoringModel): def __init__(self, main_param: int = 1):