Add MLJ compliant docstrings #130

Merged
merged 11 commits into from
Mar 18, 2025
src/MLJInterface.jl: 380 changes (378 additions, 2 deletions)
@@ -397,7 +397,7 @@ MLJModelInterface.metadata_model(
input=MLJModelInterface.Table(MLJModelInterface.Continuous),
target=AbstractVector{<:MLJModelInterface.Finite},
weights=true,
descr="Microsoft LightGBM FFI wrapper: Classifier",
human_name="LightGBM classifier",
)

MLJModelInterface.metadata_model(
@@ -406,7 +406,383 @@ MLJModelInterface.metadata_model(
input=MLJModelInterface.Table(MLJModelInterface.Continuous),
target=AbstractVector{MLJModelInterface.Continuous},
weights=true,
descr="Microsoft LightGBM FFI wrapper: Regressor",
human_name="LightGBM regressor",
)

"""
$(MLJModelInterface.doc_header(LGBMRegressor))

LightGBM, short for light gradient-boosting machine, is a
framework for gradient boosting based on decision tree algorithms and used for
classification, regression and other machine learning tasks, with a focus on
performance and scalability. This model in particular is used for various types of
regression tasks.

# Training data

In MLJ or MLJBase, bind an instance `model` to data with

mach = machine(model, X, y)

Here:

- `X` is any table of input features (eg, a `DataFrame`) whose columns are of
scitype `Continuous`; check the column scitypes with `schema(X)`; alternatively,
`X` is any `AbstractMatrix` with `Continuous` elements; check the scitype with
`scitype(X)`.
- `y` is a vector of targets whose items are of scitype `Continuous`. Check the
  scitype with `scitype(y)`.

Train the machine using `fit!(mach, rows=...)`.
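
For example, a minimal sketch of row-based training, assuming `X` and `y` have been
loaded as described above and `LGBMRegressor` has been brought into scope with
`@load` (as in the Examples section below):

```julia
using MLJ

# hold out 30% of the rows for testing
train, test = partition(collect(eachindex(y)), 0.70, shuffle=true)

mach = machine(LGBMRegressor(), X, y)
fit!(mach, rows=train)   # fit on the training rows only
```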

# Operations

- `predict(mach, Xnew)`: return predictions of the target given new features
`Xnew`, which should have the same scitype as `X` above.

# Hyper-parameters

- `boosting::String = "gbdt"`: Which boosting algorithm to use. One of:
  - "gbdt": traditional gradient boosting.
  - "rf": random forest.
  - "dart": dropout additive regression trees.
  - "goss": gradient-based one-side sampling.
- `num_iterations::Int = 10`: Number of iterations to run the boosting algorithm.
- `learning_rate::Float64 = 0.1`: The update or shrinkage rate. In `dart`
boosting, also affects the normalization weights of dropped trees.
- `num_leaves::Int = 31`: The maximum number of leaves in one tree.
- `max_depth::Int = -1`: The limit on the maximum depth of a tree. Used to reduce
  overfitting. Set to `≤ 0` for unlimited depth.
- `tree_learner::String = "serial"`: The tree learning mode. One of:
  - "serial": single machine tree learner.
  - "feature": feature-parallel tree learner.
  - "data": data-parallel tree learner.
  - "voting": voting-parallel tree learner. See the [LightGBM distributed
    learning
    guide](https://lightgbm.readthedocs.io/en/latest/Parallel-Learning-Guide.html)
    for details.
- `histogram_pool_size::Float64 = -1.0`: Max size in MB for the historical
histogram. Set to `≤0` for an unlimited size.
- `min_data_in_leaf::Int = 20`: Minimal number of data in one leaf. Can be used to
deal with over-fitting.
- `min_sum_hessian_in_leaf::Float64 = 1e-3`: Minimal sum hessian in one leaf. Like
`min_data_in_leaf`, it can be used to deal with over-fitting.
- `max_delta_step::Float64 = 0.0`: Used to limit the max output of tree leaves.
  The final maximum output of leaves is `max_delta_step * learning_rate`. A value
  of `≤ 0` means no limit on the max output.
- `lambda_l1::Float64 = 0.0`: L1 regularization.
- `lambda_l2::Float64 = 0.0`: L2 regularization.
- `min_gain_to_split::Float64 = 0.0`: The minimal gain required to perform a
split. Can be used to speed up training.
- `feature_fraction::Float64 = 1.0`: The fraction of features to select before
fitting a tree. Can be used to speed up training and reduce over-fitting.
- `feature_fraction_bynode::Float64 = 1.0`: The fraction of features to select for
each tree node. Can be used to reduce over-fitting.
- `feature_fraction_seed::Int = 2`: Random seed to use for `feature_fraction`.
- `bagging_fraction::Float64 = 1.0`: The fraction of samples to use before
fitting a tree. Can be used to speed up training and reduce over-fitting.
- `bagging_freq::Int = 0`: The frequency to perform bagging at. At frequency `k`,
  every `k` iterations a fraction `bagging_fraction` of the data is selected and
  used for the next `k` iterations.
- `bagging_seed::Int = 3`: The random seed to use for bagging.
- `early_stopping_round::Int = 0`: Will stop training if a validation metric does
not improve over `early_stopping_round` rounds.
- `extra_trees::Bool = false`: Use extremely randomized trees. If true, will only
check one randomly chosen threshold before splitting. Can be used to speed up
training and reduce over-fitting.
- `extra_seed::Int = 6`: The random seed to use for `extra_trees`.
- `max_bin::Int = 255`: Number of bins feature values will be bucketed in. Smaller
values may reduce training accuracy and help alleviate over-fitting.
- `bin_construct_sample_cnt = 200000`: Number of samples to use to construct bins.
Larger values will give better results but may increase data loading time.
- `init_score::String = ""`: The initial score to try and correct in the first
boosting iteration.
- `drop_rate::Float64 = 0.1`: The dropout rate for `dart`.
- `max_drop::Int = 50`: The maximum number of trees to drop in `dart`.
- `skip_drop::Float64 = 0.5`: Probability of skipping dropout in `dart`.
- `xgboost_dart_mode::Bool`: Set to `true` to use xgboost's dart mode in `dart`
  boosting.
- `uniform_drop::Bool`: Set to `true` to use uniform dropout in `dart`.
- `drop_seed::Int = 4`: Random seed for `dart` dropout.
- `top_rate::Float64 = 0.2`: The retain ratio of large-gradient data in `goss`.
- `other_rate::Float64 = 0.1`: The retain ratio of small-gradient data in `goss`.
- `min_data_per_group::Int = 100`: Minimal amount of data per categorical group.
- `max_cat_threshold::Int = 32`: Limits the number of split points considered for
categorical features.
- `cat_l2::Float64 = 10.0`: L2 regularization for categorical splits.
- `cat_smooth::Float64 = 10.0`: Reduces noise in categorical features,
  particularly useful when there are categories with little data.
- `objective::String = "regression"`: The objective function to use. One of:
- "regression": L2 loss or mse.
- "regression_l1": L1 loss or mae.
- "huber": Huber loss.
- "fair": Fair loss.
- "poisson": poisson regression.
- "quantile": Quantile regression.
- "mape": MAPE (mean mean_absolute_percentage_error) loss.
- "gamma": Gamma regression with log-link.
- "tweedie": Tweedie regression with log-link.
- `categorical_feature::Vector{Int} = Vector{Int}()`: Used to specify the
categorical features. Items in the vector are column indices representing which
features should be interpreted as categorical.
- `data_random_seed::Int = 1`: Random seed used when constructing histogram bins.
- `is_sparse::Bool = true`: Enable/disable sparse optimization.
- `is_unbalance::Bool = false`: Set to true if training data is unbalanced.
- `boost_from_average::Bool = true`: Adjusts the initial score to the mean of
labels for faster convergence.
- `use_missing::Bool = true`: Whether or not to handle missing values.
- `feature_pre_filter::Bool = true`: Whether or not to ignore unsplittable
features.
- `alpha::Float64 = 0.9`: Parameter used for huber and quantile regression.
- `metric::Vector{String} = ["l2"]`: Metric(s) to be used when evaluating on
evaluation set. For detailed information, see [the official
documentation](https://lightgbm.readthedocs.io/en/latest/Parameters.html#metric-parameters)
- `metric_freq::Int = 1`: The frequency to run metric evaluation at.
- `is_training_metric::Bool = false`: Set to `true` to output metric results on the
  training dataset.
- `ndcg_at::Vector{Int} = Vector{Int}([1, 2, 3, 4, 5])`: Evaluation positions for
ndcg and map metrics.
- `num_machines::Int = 1`: Number of machines to use when doing distributed
learning.
- `num_threads::Int = 0`: Number of threads to use.
- `local_listen_port::Int = 12400`: TCP listen port.
- `time_out::Int = 120`: Socket timeout.
- `machine_list_file::String = ""`: Path of the file that lists the machines used
  for distributed learning.
- `save_binary::Bool = false`: Whether or not to save the dataset to a binary file.
- `device_type::String = "cpu"`: The type of device to use. One of `"cpu"` or
  `"gpu"`.
- `force_col_wise::Bool = false`: Force column-wise histogram building. Only
  applicable on CPU.
- `force_row_wise::Bool = false`: Force row-wise histogram building. Only
  applicable on CPU.
- `truncate_booster::Bool = true`: Whether or not to truncate the booster.
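
As an illustration of overriding some of the hyper-parameters above, the following
sketch constructs a regressor with non-default settings (the particular values are
arbitrary, not tuned recommendations):

```julia
using MLJ
LGBMRegressor = @load LGBMRegressor  # requires the LightGBM package in the environment

# illustrative values only; unspecified hyper-parameters keep the defaults above
lgb = LGBMRegressor(
    num_iterations = 100,
    learning_rate = 0.05,
    num_leaves = 63,
    min_data_in_leaf = 20,
    lambda_l2 = 1.0,
)
```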

# Fitted parameters

The fields of `fitted_params(mach)` are:

- `fitresult`: Fitted model information, containing an `LGBMRegression` object, an
  empty vector, and the regressor with all its parameters.

# Report

The fields of `report(mach)` are:

- `training_metrics`: A dictionary containing all training metrics.
- `importance`: A named tuple containing:
  - `gain`: The total gain of each split used by the model.
  - `split`: The number of times each feature is used by the model.
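
For instance, once a machine `mach` has been fitted, these can be inspected along
the following lines (a sketch, assuming the field layout documented above):

```julia
fp = fitted_params(mach)
fp.fitresult             # LGBMRegression object, empty vector, regressor parameters

rep = report(mach)
rep.training_metrics     # dictionary of training metrics
rep.importance.gain      # total gain of each feature's splits
rep.importance.split     # number of times each feature is used
```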

# Examples

```julia

using DataFrames
using MLJ

# load the model (make sure to Pkg.add LightGBM to the environment)
LGBMRegressor = @load LGBMRegressor

X, y = @load_boston # a table and a vector
X = DataFrame(X)
train, test = partition(collect(eachindex(y)), 0.70, shuffle=true)

first(X, 3)
lgb = LGBMRegressor() # initialise a model with default params
mach = machine(lgb, X[train, :], y[train]) |> fit!

predict(mach, X[test, :])
```

"""
LGBMRegressor


"""
$(MLJModelInterface.doc_header(LGBMClassifier))

LightGBM, short for light gradient-boosting machine, is a
framework for gradient boosting based on decision tree algorithms and used for
classification and other machine learning tasks, with a focus on
performance and scalability. This model in particular is used for various types of
classification tasks.

# Training data

In MLJ or MLJBase, bind an instance `model` to data with

mach = machine(model, X, y)

Here:

- `X` is any table of input features (eg, a `DataFrame`) whose columns are of
scitype `Continuous`; check the column scitypes with `schema(X)`; alternatively,
`X` is any `AbstractMatrix` with `Continuous` elements; check the scitype with
`scitype(X)`.
- `y` is a vector of targets whose items are of scitype `Finite`; check the
  scitype with `scitype(y)`.

Train the machine using `fit!(mach, rows=...)`.

# Operations

- `predict(mach, Xnew)`: return probabilistic predictions of the target given new
  features `Xnew`, which should have the same scitype as `X` above.
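
A minimal sketch, assuming `mach` is a trained machine and `Xnew` has the same
scitype as `X`; since predictions are probabilistic, MLJ's `predict_mode` can be
used to recover point predictions:

```julia
yhat = predict(mach, Xnew)   # one class-probability distribution per observation
predict_mode(mach, Xnew)     # most probable class for each observation
```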

# Hyper-parameters

- `boosting::String = "gbdt"`: Which boosting algorithm to use. One of:
  - "gbdt": traditional gradient boosting.
  - "rf": random forest.
  - "dart": dropout additive regression trees.
  - "goss": gradient-based one-side sampling.
- `num_iterations::Int = 10`: Number of iterations to run the boosting algorithm.
- `learning_rate::Float64 = 0.1`: The update or shrinkage rate. In `dart`
boosting, also affects the normalization weights of dropped trees.
- `num_leaves::Int = 31`: The maximum number of leaves in one tree.
- `max_depth::Int = -1`: The limit on the maximum depth of a tree.
- `tree_learner::String = "serial"`: The tree learning mode. One of:
  - "serial": single machine tree learner.
  - "feature": feature-parallel tree learner.
  - "data": data-parallel tree learner.
  - "voting": voting-parallel tree learner. See the [LightGBM distributed learning
    guide](https://lightgbm.readthedocs.io/en/latest/Parallel-Learning-Guide.html)
    for details.
- `histogram_pool_size::Float64 = -1.0`: Max size in MB for the historical
histogram.
- `min_data_in_leaf::Int = 20`: Minimal number of data in one leaf. Can be used to
deal with over-fitting.
- `min_sum_hessian_in_leaf::Float64 = 1e-3`: Minimal sum hessian in one leaf. Like
  `min_data_in_leaf`, it can be used to deal with over-fitting.
- `max_delta_step::Float64 = 0.0`: Used to limit the max output of tree leaves.
  The final maximum output of leaves is `max_delta_step * learning_rate`.
- `lambda_l1::Float64 = 0.0`: L1 regularization.
- `lambda_l2::Float64 = 0.0`: L2 regularization.
- `min_gain_to_split::Float64 = 0.0`: The minimal gain required to perform a
split. Can be used to speed up training.
- `feature_fraction::Float64 = 1.0`: The fraction of features to select before
fitting a tree. Can be used to speed up training and reduce over-fitting.
- `feature_fraction_bynode::Float64 = 1.0`: The fraction of features to select for
each tree node. Can be used to reduce over-fitting.
- `feature_fraction_seed::Int = 2`: Random seed to use for `feature_fraction`.
- `bagging_fraction::Float64 = 1.0`: The fraction of samples to use before
fitting a tree. Can be used to speed up training and reduce over-fitting.
- `bagging_freq::Int = 0`: The frequency to perform bagging at. At frequency `k`,
  every `k` iterations a fraction `bagging_fraction` of the data is selected and
  used for the next `k` iterations.
- `bagging_seed::Int = 3`: The random seed to use for bagging.
- `early_stopping_round::Int = 0`: Will stop training if a validation metric does
not improve over `early_stopping_round` rounds.
- `extra_trees::Bool = false`: Use extremely randomized trees. If true, will only
check one randomly chosen threshold before splitting. Can be used to speed up
training and reduce over-fitting.
- `extra_seed::Int = 6`: The random seed to use for `extra_trees`.
- `max_bin::Int = 255`: Number of bins feature values will be bucketed in. Smaller
values may reduce training accuracy and help alleviate over-fitting.
- `bin_construct_sample_cnt = 200000`: Number of samples to use to construct bins.
Larger values will give better results but may increase data loading time.
- `init_score::String = ""`: The initial score to try and correct in the first
boosting iteration.
- `drop_rate::Float64 = 0.1`: The dropout rate for `dart`.
- `max_drop::Int = 50`: The maximum number of trees to drop in `dart`.
- `skip_drop::Float64 = 0.5`: Probability of skipping dropout in `dart`.
- `xgboost_dart_mode::Bool`: Set to `true` to use xgboost's dart mode in `dart`
  boosting.
- `uniform_drop::Bool`: Set to `true` to use uniform dropout in `dart`.
- `drop_seed::Int = 4`: Random seed for `dart` dropout.
- `top_rate::Float64 = 0.2`: The retain ratio of large-gradient data in `goss`.
- `other_rate::Float64 = 0.1`: The retain ratio of small-gradient data in `goss`.
- `min_data_per_group::Int = 100`: Minimal amount of data per categorical group.
- `max_cat_threshold::Int = 32`: Limits the number of split points considered for
categorical features.
- `cat_l2::Float64 = 10.0`: L2 regularization for categorical splits.
- `cat_smooth::Float64 = 10.0`: Reduces noise in categorical features,
  particularly useful when there are categories with little data.
- `objective::String = "multiclass"`: The objective function to use. One of:
- binary: Binary log loss classification.
- multiclass: Softmax classification.
  - multiclassova: One-versus-all multiclass classification. `num_class` should
    be set as well.
- cross_entropy: Cross-entropy objective function.
- cross_entropy_lambda: Alternative parametrized form of the cross-entropy
objective function.
- lambdarank: The lambdarank objective function, for use in ranking
applications.
  - rank_xendcg: The XE_NDCG_MART ranking objective function. Faster than
    lambdarank with the same performance.
- `categorical_feature::Vector{Int} = Vector{Int}()`: Used to specify the
categorical features. Items in the vector are column indices representing which
features should be interpreted as categorical.
- `data_random_seed::Int = 1`: Random seed used when constructing histogram bins.
- `is_sparse::Bool = true`: Enable/disable sparse optimization.
- `is_unbalance::Bool = false`: Set to true if training data is unbalanced.
- `boost_from_average::Bool = true`: Adjusts the initial score to the mean of
labels for faster convergence.
- `use_missing::Bool = true`: Whether or not to handle missing values.
- `feature_pre_filter::Bool = true`: Whether or not to ignore unsplittable
features.
- `alpha::Float64 = 0.9`: Parameter used for huber and quantile regression.
- `metric::Vector{String} = ["none"]`: Metric(s) to be used when evaluating on
evaluation set. For detailed information, see [the official
documentation](https://lightgbm.readthedocs.io/en/latest/Parameters.html#metric-parameters)
- `metric_freq::Int = 1`: The frequency to run metric evaluation at.
- `is_training_metric::Bool = false`: Set to `true` to output metric results on the
  training dataset.
- `ndcg_at::Vector{Int} = Vector{Int}([1, 2, 3, 4, 5])`: Evaluation positions for
ndcg and map metrics.
- `num_machines::Int = 1`: Number of machines to use when doing distributed
learning.
- `num_threads::Int = 0`: Number of threads to use.
- `local_listen_port::Int = 12400`: TCP listen port.
- `time_out::Int = 120`: Socket timeout.
- `machine_list_file::String = ""`: Path of the file that lists the machines used
  for distributed learning.
- `save_binary::Bool = false`: Whether or not to save the dataset to a binary file.
- `device_type::String = "cpu"`: The type of device to use. One of `"cpu"` or
  `"gpu"`.
- `force_col_wise::Bool = false`: Force column-wise histogram building. Only
  applicable on CPU.
- `force_row_wise::Bool = false`: Force row-wise histogram building. Only
  applicable on CPU.
- `truncate_booster::Bool = true`: Whether or not to truncate the booster.
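
As an illustration of overriding some of the hyper-parameters above, a classifier
with non-default settings might be constructed as follows (the particular values
are arbitrary, and the metric name is only an example from the LightGBM
documentation linked above):

```julia
using MLJ
LGBMClassifier = @load LGBMClassifier  # requires the LightGBM package in the environment

# illustrative values only; unspecified hyper-parameters keep the defaults above
lgb = LGBMClassifier(
    objective = "multiclass",
    num_iterations = 100,
    learning_rate = 0.1,
    num_leaves = 31,
    max_depth = 6,
    metric = ["multi_logloss"],
)
```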

# Fitted parameters

The fields of `fitted_params(mach)` are:

- `fitresult`: Fitted model information, containing an `LGBMClassification`
  object, a `CategoricalArray` of the input class names, and the classifier with
  all its parameters.

# Report

The fields of `report(mach)` are:

- `training_metrics`: A dictionary containing all training metrics.
- `importance`: A named tuple containing:
  - `gain`: The total gain of each split used by the model.
  - `split`: The number of times each feature is used by the model.
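
For instance, once a machine `mach` has been trained (a sketch, assuming the tuple
layout of `fitresult` documented above):

```julia
fp = fitted_params(mach)
booster, classes, params = fp.fitresult   # assuming the three-element layout above

rep = report(mach)
rep.importance.split                      # number of times each feature is used
```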


# Examples

```julia

using DataFrames
using MLJ

# load the model (make sure to Pkg.add LightGBM to the environment)
LGBMClassifier = @load LGBMClassifier

X, y = @load_iris
X = DataFrame(X)
train, test = partition(collect(eachindex(y)), 0.70, shuffle=true)

first(X, 3)
lgb = LGBMClassifier() # initialise a model with default params
mach = machine(lgb, X[train, :], y[train]) |> fit!

predict(mach, X[test, :])
```

"""
LGBMClassifier

end # module