Skip to content

Commit f33c099

Browse files
authored
Merge pull request #182 from StochasticTree/multi-threaded-sampling
Support for multi threading in various parts of the sampling algorithm
2 parents 0ded1e8 + a85cfcb commit f33c099

File tree

89 files changed

+6307
-755
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

89 files changed

+6307
-755
lines changed

.gitattributes

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
* text=auto
2+
3+
*.c text eol=lf
4+
*.h text eol=lf
5+
*.cc text eol=lf
6+
*.cuh text eol=lf
7+
*.cu text eol=lf
8+
*.py text eol=lf
9+
*.txt text eol=lf
10+
*.R text eol=lf
11+
12+
*.sh text eol=lf
13+
*.ac text eol=lf
14+
15+
*.md text eol=lf
16+
*.csv text eol=lf

.github/workflows/cpp-test.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,12 @@ jobs:
5656
shell: bash
5757
run: |
5858
echo "build-output-dir=${{ github.workspace }}/build" >> "$GITHUB_OUTPUT"
59+
60+
- name: Set up dependencies (linux clang)
61+
# Install OpenMP on ubuntu-latest when building with clang: unlike GCC and MSVC, the clang toolset does not ship with an OpenMP runtime
62+
if: matrix.os == 'ubuntu-latest' && matrix.cpp_compiler == 'clang++'
63+
run: |
64+
sudo apt-get update && sudo apt-get install -y libomp-dev
5965
6066
- name: Configure CMake
6167
# Configure CMake in a 'build' subdirectory. `CMAKE_BUILD_TYPE` is only required if you are using a single-configuration generator such as make.
@@ -69,6 +75,7 @@ jobs:
6975
-DUSE_SANITIZER=OFF
7076
-DBUILD_TEST=ON
7177
-DBUILD_DEBUG_TARGETS=OFF
78+
-DUSE_OPENMP=ON
7279
-S ${{ github.workspace }}
7380
7481
- name: Build

.github/workflows/pypi-wheels.yml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,26 +21,37 @@ jobs:
2121
include:
2222
- os: ubuntu-latest
2323
cibw_archs: "x86_64"
24+
macos_deployment_target: "10.13" # Unused, just setting the variable as a placeholder
2425
- os: ubuntu-24.04-arm
2526
cibw_archs: "aarch64"
27+
macos_deployment_target: "10.13" # Unused, just setting the variable as a placeholder
2628
- os: windows-latest
2729
cibw_archs: "auto64"
30+
macos_deployment_target: "10.13" # Unused, just setting the variable as a placeholder
2831
- os: macos-13
2932
cibw_archs: "x86_64"
33+
macos_deployment_target: "13.0"
3034
- os: macos-14
3135
cibw_archs: "arm64"
36+
macos_deployment_target: "14.0"
3237

3338
steps:
3439
- uses: actions/checkout@v4
3540
with:
3641
submodules: 'recursive'
42+
43+
- name: Set up openmp (macos)
44+
# Set up openMP on MacOS since it doesn't ship with the apple clang compiler suite
45+
if: matrix.os == 'macos-13' || matrix.os == 'macos-14'
46+
run: |
47+
brew install libomp
3748
3849
- name: Build wheels
3950
uses: pypa/[email protected]
4051
env:
4152
CIBW_SKIP: "pp* *-musllinux_* *-win32"
4253
CIBW_ARCHS: ${{ matrix.cibw_archs }}
43-
MACOSX_DEPLOYMENT_TARGET: "10.13"
54+
MACOSX_DEPLOYMENT_TARGET: ${{ matrix.macos_deployment_target }}
4455

4556
- uses: actions/upload-artifact@v4
4657
with:

.github/workflows/python-test.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ jobs:
3030
with:
3131
python-version: "3.10"
3232
cache: "pip"
33+
34+
- name: Set up openmp (macos)
35+
# Set up openMP on MacOS since it doesn't ship with the apple clang compiler suite
36+
if: matrix.os == 'macos-latest'
37+
run: |
38+
brew install libomp
3339
3440
- name: Install Package with Relevant Dependencies
3541
run: |

.github/workflows/r-test.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ jobs:
2222
os: [ubuntu-latest, windows-latest, macos-latest]
2323

2424
steps:
25+
- name: Prevent conversion of line endings on Windows
26+
if: startsWith(matrix.os, 'windows')
27+
shell: pwsh
28+
run: git config --global core.autocrlf false
29+
2530
- uses: actions/checkout@v4
2631
with:
2732
submodules: 'recursive'

.github/workflows/regression-test.yml

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
on:
2+
workflow_dispatch:
3+
4+
name: Running stochtree on benchmark datasets
5+
6+
jobs:
7+
stochtree_r:
8+
name: stochtree-r-bart-regression-test
9+
runs-on: ubuntu-latest
10+
11+
steps:
12+
- name: Checkout stochtree repo
13+
uses: actions/checkout@v4
14+
with:
15+
submodules: 'recursive'
16+
17+
- name: Setup pandoc
18+
uses: r-lib/actions/setup-pandoc@v2
19+
20+
- name: Setup R
21+
uses: r-lib/actions/setup-r@v2
22+
with:
23+
use-public-rspm: true
24+
25+
- name: Create a properly formatted version of the stochtree R package in a subfolder
26+
run: |
27+
Rscript cran-bootstrap.R 0 0 1
28+
29+
- name: Setup R dependencies
30+
uses: r-lib/actions/setup-r-dependencies@v2
31+
with:
32+
extra-packages: any::testthat, any::decor, local::stochtree_cran
33+
34+
- name: Create output directory for BART regression test results
35+
run: |
36+
mkdir -p tools/regression/bart/stochtree_bart_r_results
37+
mkdir -p tools/regression/bcf/stochtree_bcf_r_results
38+
39+
- name: Run the BART regression test benchmark suite
40+
run: |
41+
Rscript tools/regression/bart/regression_test_dispatch_bart.R
42+
Rscript tools/regression/bcf/regression_test_dispatch_bcf.R
43+
44+
- name: Collate and analyze regression test results
45+
run: |
46+
Rscript tools/regression/bart/regression_test_analysis_bart.R
47+
Rscript tools/regression/bcf/regression_test_analysis_bcf.R
48+
49+
- name: Store BART benchmark test results as an artifact of the run
50+
uses: actions/upload-artifact@v4
51+
with:
52+
name: stochtree-r-bart-summary
53+
path: tools/regression/bart/stochtree_bart_r_results/stochtree_bart_r_summary.csv
54+
55+
- name: Store BCF benchmark test results as an artifact of the run
56+
uses: actions/upload-artifact@v4
57+
with:
58+
name: stochtree-r-bcf-summary
59+
path: tools/regression/bcf/stochtree_bcf_r_results/stochtree_bcf_r_summary.csv

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ po/*~
7070
# RStudio Connect folder
7171
rsconnect/
7272

73+
# Configuration files generated by R build
74+
config.status
75+
config.log
76+
src/Makevars
77+
7378
## Python gitignore
7479

7580
# Byte-compiled / optimized / DLL files

CMakeLists.txt

Lines changed: 69 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Build options
2-
option(USE_DEBUG "Set to ON for Debug mode" OFF)
2+
option(USE_DEBUG "Build with debug symbols and without optimization" OFF)
33
option(USE_SANITIZER "Use santizer flags" OFF)
4+
option(USE_OPENMP "Use openMP" ON)
5+
option(USE_HOMEBREW_FALLBACK "(macOS-only) also look in 'brew --prefix' for libraries (e.g. OpenMP)" ON)
46
option(BUILD_TEST "Build C++ tests with Google Test" OFF)
57
option(BUILD_DEBUG_TARGETS "Build Standalone C++ Programs for Debugging" ON)
68
option(BUILD_PYTHON "Build Shared Library for Python Package" OFF)
@@ -9,8 +11,8 @@ option(BUILD_PYTHON "Build Shared Library for Python Package" OFF)
911
set(CMAKE_CXX_STANDARD 17)
1012
set(CMAKE_CXX_STANDARD_REQUIRED ON)
1113

12-
# Default to CMake 3.16
13-
cmake_minimum_required(VERSION 3.16)
14+
# Require CMake 3.20 or newer
15+
cmake_minimum_required(VERSION 3.20)
1416

1517
# Define the project
1618
project(stochtree LANGUAGES C CXX)
@@ -34,6 +36,13 @@ if(USE_DEBUG)
3436
add_definitions(-DDEBUG)
3537
endif()
3638

39+
# Linker flags (empty by default, updated if using openmp)
40+
set(
41+
STOCHTREE_LINK_FLAGS
42+
""
43+
)
44+
45+
# Unix / MinGW compiler flags
3746
if(UNIX OR MINGW OR CYGWIN)
3847
set(
3948
CMAKE_CXX_FLAGS
@@ -42,11 +51,12 @@ if(UNIX OR MINGW OR CYGWIN)
4251
if(USE_DEBUG)
4352
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O0")
4453
else()
45-
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
54+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O3")
4655
endif()
4756
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unknown-pragmas -Wno-unused-private-field")
4857
endif()
4958

59+
# MSVC compiler flags
5060
if(MSVC)
5161
set(
5262
variables
@@ -72,6 +82,33 @@ else()
7282
endif()
7383
endif()
7484

85+
# OpenMP
86+
if(USE_OPENMP)
87+
add_definitions(-DSTOCHTREE_OPENMP_AVAILABLE)
88+
if(APPLE)
89+
find_package(OpenMP)
90+
if(NOT OpenMP_FOUND)
91+
if(USE_HOMEBREW_FALLBACK)
92+
execute_process(COMMAND brew --prefix libomp
93+
OUTPUT_VARIABLE HOMEBREW_LIBOMP_PREFIX
94+
OUTPUT_STRIP_TRAILING_WHITESPACE)
95+
set(OpenMP_C_FLAGS "-Xclang -fopenmp -I${HOMEBREW_LIBOMP_PREFIX}/include")
96+
set(OpenMP_CXX_FLAGS "-Xclang -fopenmp -I${HOMEBREW_LIBOMP_PREFIX}/include")
97+
set(OpenMP_C_INCLUDE_DIR "")
98+
set(OpenMP_CXX_INCLUDE_DIR "")
99+
set(OpenMP_C_LIB_NAMES libomp)
100+
set(OpenMP_CXX_LIB_NAMES libomp)
101+
set(OpenMP_libomp_LIBRARY ${HOMEBREW_LIBOMP_PREFIX}/lib/libomp.dylib)
102+
endif()
103+
find_package(OpenMP REQUIRED)
104+
endif()
105+
else()
106+
find_package(OpenMP REQUIRED)
107+
endif()
108+
# Update flags with openmp
109+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
110+
endif()
111+
75112
# Header file directory
76113
set(StochTree_HEADER_DIR ${PROJECT_SOURCE_DIR}/include)
77114

@@ -80,6 +117,8 @@ set(BOOSTMATH_HEADER_DIR ${PROJECT_SOURCE_DIR}/deps/boost_math/include)
80117

81118
# Eigen header file directory
82119
set(EIGEN_HEADER_DIR ${PROJECT_SOURCE_DIR}/deps/eigen)
120+
add_definitions(-DEIGEN_MPL2_ONLY)
121+
add_definitions(-DEIGEN_DONT_PARALLELIZE)
83122

84123
# fast_double_parser header file directory
85124
set(FAST_DOUBLE_PARSER_HEADER_DIR ${PROJECT_SOURCE_DIR}/deps/fast_double_parser/include)
@@ -109,10 +148,11 @@ file(
109148
add_library(stochtree_objs OBJECT ${SOURCES})
110149

111150
# Include the headers in the source library
112-
target_include_directories(stochtree_objs PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
113-
114-
if(APPLE)
115-
set(CMAKE_SHARED_LIBRARY_SUFFIX ".so")
151+
if(USE_OPENMP)
152+
target_include_directories(stochtree_objs PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR} ${OpenMP_CXX_INCLUDE_DIR})
153+
target_link_libraries(stochtree_objs PRIVATE ${OpenMP_libomp_LIBRARY})
154+
else()
155+
target_include_directories(stochtree_objs PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
116156
endif()
117157

118158
# Python shared library
@@ -122,8 +162,13 @@ if (BUILD_PYTHON)
122162
pybind11_add_module(stochtree_cpp src/py_stochtree.cpp)
123163

124164
# Link to C++ source and headers
125-
target_include_directories(stochtree_cpp PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
126-
target_link_libraries(stochtree_cpp PRIVATE stochtree_objs)
165+
if(USE_OPENMP)
166+
target_include_directories(stochtree_cpp PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR} ${OpenMP_CXX_INCLUDE_DIR})
167+
target_link_libraries(stochtree_cpp PRIVATE stochtree_objs ${OpenMP_libomp_LIBRARY})
168+
else()
169+
target_include_directories(stochtree_cpp PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
170+
target_link_libraries(stochtree_cpp PRIVATE stochtree_objs)
171+
endif()
127172

128173
# EXAMPLE_VERSION_INFO is defined by setup.py and passed into the C++ code as a
129174
# define (VERSION_INFO) here.
@@ -154,8 +199,13 @@ if(BUILD_TEST)
154199
file(GLOB CPP_TEST_SOURCES test/cpp/*.cpp)
155200
add_executable(teststochtree ${CPP_TEST_SOURCES})
156201
set(STOCHTREE_TEST_HEADER_DIR ${PROJECT_SOURCE_DIR}/test/cpp)
157-
target_include_directories(teststochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${STOCHTREE_TEST_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
158-
target_link_libraries(teststochtree PRIVATE stochtree_objs GTest::gtest_main)
202+
if(USE_OPENMP)
203+
target_include_directories(teststochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${STOCHTREE_TEST_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR} ${OpenMP_CXX_INCLUDE_DIR})
204+
target_link_libraries(teststochtree PRIVATE stochtree_objs GTest::gtest_main ${OpenMP_libomp_LIBRARY})
205+
else()
206+
target_include_directories(teststochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${STOCHTREE_TEST_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
207+
target_link_libraries(teststochtree PRIVATE stochtree_objs GTest::gtest_main)
208+
endif()
159209
gtest_discover_tests(teststochtree)
160210
endif()
161211

@@ -164,7 +214,12 @@ if(BUILD_DEBUG_TARGETS)
164214
# Build standalone debug program
165215
add_executable(debugstochtree debug/api_debug.cpp)
166216
set(StochTree_DEBUG_HEADER_DIR ${PROJECT_SOURCE_DIR}/debug)
167-
target_include_directories(debugstochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${StochTree_DEBUG_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
168-
target_link_libraries(debugstochtree PRIVATE stochtree_objs)
217+
if(USE_OPENMP)
218+
target_include_directories(debugstochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${StochTree_DEBUG_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR} ${OpenMP_CXX_INCLUDE_DIR})
219+
target_link_libraries(debugstochtree PRIVATE stochtree_objs ${OpenMP_libomp_LIBRARY})
220+
else()
221+
target_include_directories(debugstochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${StochTree_DEBUG_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
222+
target_link_libraries(debugstochtree PRIVATE stochtree_objs)
223+
endif()
169224
endif()
170225

R/bart.R

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
#' - `rfx_group_parameter_prior_cov` Prior covariance matrix for the random effects "group parameters." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
5252
#' - `rfx_variance_prior_shape` Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
5353
#' - `rfx_variance_prior_scale` Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
54+
#' - `num_threads` Number of threads to use in the GFR and MCMC algorithms, as well as prediction. Default: `-1`. With the default value, a single thread (`1`) is used if OpenMP is not available on a user's setup; otherwise the maximum number of available threads is used.
5455
#'
5556
#' @param mean_forest_params (Optional) A list of mean forest model parameters, each of which has a default value processed internally, so this argument list is optional.
5657
#'
@@ -130,7 +131,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
130131
rfx_working_parameter_prior_cov = NULL,
131132
rfx_group_parameter_prior_cov = NULL,
132133
rfx_variance_prior_shape = 1,
133-
rfx_variance_prior_scale = 1
134+
rfx_variance_prior_scale = 1,
135+
num_threads = -1
134136
)
135137
general_params_updated <- preprocessParams(
136138
general_params_default, general_params
@@ -186,6 +188,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
186188
rfx_group_parameter_prior_cov <- general_params_updated$rfx_group_parameter_prior_cov
187189
rfx_variance_prior_shape <- general_params_updated$rfx_variance_prior_shape
188190
rfx_variance_prior_scale <- general_params_updated$rfx_variance_prior_scale
191+
num_threads <- general_params_updated$num_threads
189192

190193
# 2. Mean forest parameters
191194
num_trees_mean <- mean_forest_params_updated$num_trees
@@ -795,7 +798,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
795798
forest_model_mean$sample_one_iteration(
796799
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean,
797800
active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean,
798-
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
801+
global_model_config = global_model_config, num_threads = num_threads,
802+
keep_forest = keep_sample, gfr = TRUE
799803
)
800804

801805
# Cache train set predictions since they are already computed during sampling
@@ -1272,15 +1276,23 @@ predict.bartmodel <- function(object, X, leaf_basis = NULL, rfx_group_ids = NULL
12721276
result <- list()
12731277
if ((object$model_params$has_rfx) || (object$model_params$include_mean_forest)) {
12741278
result[["y_hat"]] = y_hat
1279+
} else {
1280+
result[["y_hat"]] <- NULL
12751281
}
12761282
if (object$model_params$include_mean_forest) {
12771283
result[["mean_forest_predictions"]] = mean_forest_predictions
1284+
} else {
1285+
result[["mean_forest_predictions"]] <- NULL
12781286
}
12791287
if (object$model_params$has_rfx) {
12801288
result[["rfx_predictions"]] = rfx_predictions
1289+
} else {
1290+
result[["rfx_predictions"]] <- NULL
12811291
}
12821292
if (object$model_params$include_variance_forest) {
12831293
result[["variance_forest_predictions"]] = variance_forest_predictions
1294+
} else {
1295+
result[["variance_forest_predictions"]] <- NULL
12841296
}
12851297
return(result)
12861298
}

0 commit comments

Comments
 (0)