
Commit 87c6612

Llama FSDP training with FP8

1 parent 48c339f commit 87c6612

11 files changed: +1597 -0 lines changed
@@ -0,0 +1,2 @@
checkpoints
slurm-*.out
@@ -0,0 +1,227 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0

FROM nvcr.io/nvidia/pytorch:24.04-py3
ENV DEBIAN_FRONTEND=noninteractive

# The three must-be-built packages.
# efa-installer>=1.29.1 is required for nccl>=2.19.0 to avoid a libfabric NCCL error.
ENV EFA_INSTALLER_VERSION=1.30.0
ENV AWS_OFI_NCCL_VERSION=1.8.1-aws
ENV NCCL_TESTS_VERSION=master

## Uncomment below when this Dockerfile builds a container image with efa-installer<1.29.1 and
# nccl>=2.19.0. See https://github.com/aws-samples/awsome-distributed-training/tree/main/1.architectures/efa-cheatsheet.md
#ENV FI_EFA_SET_CUDA_SYNC_MEMOPS=0

RUN apt-get update -y
RUN apt-get remove -y --allow-change-held-packages \
    libmlx5-1 ibverbs-utils libibverbs-dev libibverbs1

# We noticed that since 23.09, we can't just delete the whole /opt/hpcx/, otherwise `import torch`
# complains about missing libuc?.so.
RUN rm -rf /opt/hpcx/ompi \
    && rm -rf /usr/local/mpi \
    && rm -rf /opt/hpcx/nccl_rdma_sharp_plugin \
    && ldconfig
ENV OPAL_PREFIX=
RUN DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
    git \
    gcc \
    vim \
    kmod \
    openssh-client \
    openssh-server \
    build-essential \
    curl \
    autoconf \
    libtool \
    gdb \
    automake \
    cmake \
    apt-utils \
    libhwloc-dev \
    aptitude && \
    DEBIAN_FRONTEND=noninteractive apt autoremove -y

# EFA
RUN apt-get update && \
    cd /tmp && \
    curl -O https://efa-installer.amazonaws.com/aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz && \
    tar -xf aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz && \
    cd aws-efa-installer && \
    # ONLY add the `--skip-kmod`, `--no-verify` and `--skip-limit-conf` flags to the container image.
    # Those three flags must NOT be used on the host.
    #
    # Explanations:
    # - To build EFA in the Dockerfile, we added --skip-kmod and --no-verify. Without these flags,
    #   the Dockerfile will fail to build. If installing EFA on the host and not in a container,
    #   please remove these flags.
    # - --skip-limit-conf can be retained in the Dockerfile, but it's redundant as the host already
    #   has these limits set by efa_installer.
    ./efa_installer.sh -y -g -d --skip-kmod --no-verify --skip-limit-conf && \
    ldconfig && \
    rm -rf /tmp/aws-efa-installer /var/lib/apt/lists/*
ENV LD_LIBRARY_PATH=/opt/amazon/efa/lib:$LD_LIBRARY_PATH
ENV PATH=/opt/amazon/efa/bin:/opt/amazon/openmpi/bin:$PATH


####################################################################################################
# [CUSTOM_NCCL_OPTION_1] Uncomment the stanza below to install another NCCL version using the
# official binaries.
#
# The NCCL EFA plugin (aws-ofi-nccl) depends on mpi, hence we must rebuild openmpi before building
# aws-ofi-nccl.
####################################################################################################
#ENV NCCL_VERSION=2.19.3-1
#RUN cd /opt && \
#    wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.0-1_all.deb && \
#    dpkg -i cuda-keyring_1.0-1_all.deb && \
#    apt update && \
#    apt install -y libnccl2==${NCCL_VERSION} libnccl-dev==${NCCL_VERSION} && \
#    echo NCCL_SOCKET_IFNAME=^docker0,lo >> /etc/nccl.conf


####################################################################################################
# [CUSTOM_NCCL_OPTION_2] Install NCCL from source to the same location as the built-in one. The
# benefits of installing to the same location as the built-in version are:
#
# 1. There's only ever a single libnccl version offered by this image, preventing applications
#    from mistakenly choosing a wrong version.
# 2. No longer needing extra settings for LD_LIBRARY_PATH or LD_PRELOAD.
#
# The NCCL EFA plugin (aws-ofi-nccl) depends on mpi, hence we must rebuild openmpi before building
# aws-ofi-nccl.
####################################################################################################
ENV NCCL_VERSION=2.19.3-1
RUN apt-get remove -y libnccl2 libnccl-dev \
    && cd /tmp \
    && git clone https://github.com/NVIDIA/nccl.git -b v${NCCL_VERSION} \
    && cd nccl \
    && make -j src.build BUILDDIR=/usr \
    # Build for p4 & p5.
    NVCC_GENCODE="-gencode=arch=compute_90,code=sm_90 -gencode=arch=compute_80,code=sm_80" \
    && rm -rf /tmp/nccl \
    && echo NCCL_SOCKET_IFNAME=^docker0,lo >> /etc/nccl.conf

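# Hedged note, not part of the upstream build: one quick way to confirm which NCCL a running
# container actually resolved to is through PyTorch, e.g.
#   python -c "import torch; print(torch.cuda.nccl.version())"
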
####################################################################################################
# Rebuild OpenMPI with a custom PMIX version, e.g., to match what the host's Slurm is built with
# (see /opt/pmix/ on the host, or run pmix_info on the host).
#
# May be needed on rare occasions when `srun --mpi=pmix --container-image=... <mpi_application>`
# mysteriously crashes.
#
# The NCCL EFA plugin (aws-ofi-nccl) depends on mpi, hence we must rebuild openmpi before building
# aws-ofi-nccl.
####################################################################################################
ENV OPEN_MPI_PATH=/opt/amazon/openmpi

# The OpenMPI build script claims PMIX_VERSION, and complains if we use it.
ENV CUSTOM_PMIX_VERSION=4.2.6
RUN apt-get update && apt-get install -y libevent-dev \
    && cd /tmp \
    && wget https://github.com/openpmix/openpmix/releases/download/v${CUSTOM_PMIX_VERSION}/pmix-${CUSTOM_PMIX_VERSION}.tar.gz \
    && tar -xzf pmix-${CUSTOM_PMIX_VERSION}.tar.gz \
    && rm pmix-${CUSTOM_PMIX_VERSION}.tar.gz \
    && cd pmix-${CUSTOM_PMIX_VERSION}/ \
    && ./autogen.pl \
    && ./configure --prefix=/opt/pmix \
    && make -j \
    && make install \
    && echo /opt/pmix/lib > /etc/ld.so.conf.d/pmix.conf \
    && ldconfig \
    && cd / \
    && rm -fr /tmp/pmix-${CUSTOM_PMIX_VERSION}/
# To silence this runtime error message:
# [p4de-st-p4de-2:110912] PMIX ERROR: ERROR in file gds_ds12_lock_pthread.c at line 168
ENV PMIX_GDS_MODULE=^ds12 \
    PMIX_MCA_gds=^ds12

# Rebuild OpenMPI in DLC style (which it describes as "without libfabric"), with the above pmix.
ENV OMPI_VERSION=4.1.6
RUN rm -fr ${OPEN_MPI_PATH} \
    && mkdir /tmp/openmpi \
    && cd /tmp/openmpi \
    && wget --quiet https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-${OMPI_VERSION}.tar.gz \
    && tar zxf openmpi-${OMPI_VERSION}.tar.gz \
    && rm openmpi-${OMPI_VERSION}.tar.gz \
    && cd openmpi-${OMPI_VERSION} \
    && ./configure --enable-orterun-prefix-by-default --prefix=$OPEN_MPI_PATH --with-cuda=${CUDA_HOME} --with-slurm --with-pmix=/opt/pmix \
    && make -j $(nproc) all \
    && make install \
    && ldconfig \
    && cd / \
    && rm -rf /tmp/openmpi \
    && ompi_info --parsable --all | grep mpi_built_with_cuda_support:value \
    # Verify pmix from /opt/pmix/
    && ldd /opt/amazon/openmpi/lib/openmpi/mca_pmix_ext3x.so | grep '/opt/pmix/lib/libpmix.so.* ' > /opt/amazon/openmpi-pmix.txt
####################################################################################################


# NCCL EFA Plugin
RUN mkdir -p /tmp && \
    cd /tmp && \
    curl -LO https://github.com/aws/aws-ofi-nccl/archive/refs/tags/v${AWS_OFI_NCCL_VERSION}.tar.gz && \
    tar -xzf /tmp/v${AWS_OFI_NCCL_VERSION}.tar.gz && \
    rm /tmp/v${AWS_OFI_NCCL_VERSION}.tar.gz && \
    mv aws-ofi-nccl-${AWS_OFI_NCCL_VERSION} aws-ofi-nccl && \
    cd /tmp/aws-ofi-nccl && \
    ./autogen.sh && \
    ./configure --prefix=/opt/amazon/efa \
        --with-libfabric=/opt/amazon/efa \
        --with-cuda=/usr/local/cuda \
        --enable-platform-aws \
        --with-mpi=/opt/amazon/openmpi && \
    make -j$(nproc) install && \
    rm -rf /tmp/aws-ofi-nccl

# Do this to minimize the ld path env vars that users need to define when running this image.
RUN echo "/usr/local/lib" >> /etc/ld.so.conf.d/local.conf && \
    echo "/opt/amazon/openmpi/lib" >> /etc/ld.so.conf.d/efa.conf && \
    ldconfig

ENV OMPI_MCA_pml=^cm,ucx \
    OMPI_MCA_btl=tcp,self \
    OMPI_MCA_btl_tcp_if_exclude=lo,docker0 \
    OPAL_PREFIX=/opt/amazon/openmpi \
    # https://discuss.pytorch.org/t/nccl-network-is-unreachable-connection-refused-when-initializing-ddp/137352
    # https://github.com/pytorch/pytorch/issues/68893
    NCCL_SOCKET_IFNAME=^docker,lo

ENV LD_LIBRARY_PATH="/usr/local/lib:/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"

# NCCL-tests: always good to include this as a diagnostic tool.
RUN git clone https://github.com/NVIDIA/nccl-tests.git /opt/nccl-tests \
    && cd /opt/nccl-tests \
    && git checkout ${NCCL_TESTS_VERSION} \
    && make MPI=1 \
        MPI_HOME=/opt/amazon/openmpi \
        CUDA_HOME=/usr/local/cuda \
        NVCC_GENCODE="-gencode=arch=compute_90,code=sm_90 -gencode=arch=compute_80,code=sm_80"

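# Hedged usage sketch, not part of the build: on a Slurm cluster with Pyxis/Enroot, the
# nccl-tests binaries built above land under /opt/nccl-tests/build/ and make a handy fabric
# smoke test. The image name and node/GPU counts below are assumptions:
#   srun -N 2 --ntasks-per-node=8 --container-image ./transformer-engine.sqsh \
#       /opt/nccl-tests/build/all_reduce_perf -b 8 -e 2G -f 2 -g 1
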
####################################################################################################
# Custom packages. Disable as you like. NOTE: always check `pip list` to see what's already been
# installed. For example, the base container comes pre-installed with Transformer Engine, flash
# attention, triton (https://github.com/openai/triton/), etc.
####################################################################################################
# Install the xformers dependency from source, because pip install either breaks or tries to pull
# its own pytorch + cuda.
#
# Prerequisite: the build node has enough memory to compile xformers. More info in the stanza below.
RUN export TORCH_CUDA_ARCH_LIST="8.0;9.0+PTX" && \
    # On p4de.24xlarge:
    # - MAX_JOBS=16 => 145GB memory
    # - MAX_JOBS=32 => 241GB memory
    # - MAX_JOBS=48 => 243GB memory, 542.5s
    #
    # NOTE: MAX_JOBS must be exported. For some reason, `MAX_JOBS=16 pip install ...` doesn't seem
    # to work to prevent OOM.
    export MAX_JOBS=32 && \
    export NVCC_PREPEND_FLAGS="-t 32" && \
    pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers

RUN pip install transformers datasets

WORKDIR "/fsx"
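
A hedged build-and-import sketch (not part of this commit; the image tag and Dockerfile filename
are assumptions): the sbatch script below expects an Enroot squashfs named transformer-engine.sqsh
in the working directory, which could be produced roughly as follows.

    # Build the image from the Dockerfile above, then convert it for Enroot/Pyxis.
    docker build -t fsdp-fp8:latest -f ./llama_fsdp_fp8.Dockerfile .
    enroot import -o ./transformer-engine.sqsh dockerd://fsdp-fp8:latest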
@@ -0,0 +1,95 @@
#!/bin/bash

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0

#SBATCH --nodes=2 # number of nodes to use
#SBATCH --job-name=FSDP # name of your job
#SBATCH --exclusive # job has exclusive use of the resource, no sharing

set -ex;

###########################
###### User Variables #####
###########################

GPUS_PER_NODE=8 # 4 for G5.12x, 8 for P4/P5

###########################
## Environment Variables ##
###########################

## Plenty of EFA-level variables
## Comment out for non-EFA instances (G4d, P3)
## For G5.12x, comment out RDMA and fork safe
## For G4dn and other G5, comment out all
export FI_EFA_USE_DEVICE_RDMA=1 # use for p4d
export FI_EFA_FORK_SAFE=1
# export FI_LOG_LEVEL=1
export FI_PROVIDER=efa
# export NCCL_DEBUG=INFO
## Switching SYNC_MEMOPS to zero can boost throughput with FSDP
## Disables CU_POINTER_ATTRIBUTE_SYNC_MEMOPS
## Reduces memory synchronizations
## https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__UNIFIED.html
export FI_EFA_SET_CUDA_SYNC_MEMOPS=0

# Default variables for Enroot
: "${IMAGE:=$(pwd)/transformer-engine.sqsh}"
: "${DATA_PATH:=/fsx}"
: "${FSX_MOUNT:=$(pwd):$DATA_PATH}"

declare -a ARGS=(
    --container-image $IMAGE
    --container-mounts $FSX_MOUNT
)

###########################
####### Torch Dist #######
###########################

declare -a TORCHRUN_ARGS=(
    --nproc_per_node=$GPUS_PER_NODE
    --nnodes=$SLURM_JOB_NUM_NODES
    --rdzv_id=$SLURM_JOB_ID
    --rdzv_backend=c10d
    --rdzv_endpoint=$(hostname)
)

export TORCHRUN=torchrun
export TRAIN_SCRIPT=./train.py

############################
# Llama 2 Training Params ##
############################

declare -a TRAINING_ARGS=(
    --max_context_width=4096
    --num_key_value_heads=32 # 7b: 32 13b: 40 70b: 8
    --intermediate_size=11008 # 7b: 11008 13b: 13824 70b: 28672
    --hidden_width=4096 # 7b: 4096 13b: 5120 70b: 8192
    --num_layers=32 # 7b: 32 13b: 40 70b: 80
    --num_heads=32 # 7b: 32 13b: 40 70b: 64
    --model_type=llama_v2
    --tokenizer="hf-internal-testing/llama-tokenizer"
    --checkpoint_freq=5000
    --validation_freq=100
    --max_steps=5000
    --checkpoint_dir=./checkpoints
    --dataset='c4'
    --dataset_config_name='en'
    --resume_from_checkpoint=./checkpoints
    --train_batch_size=1
    --val_batch_size=1
    --sharding_strategy="full" # https://pytorch.org/docs/stable/fsdp.html
    --offload_activations=1
    --fp8=1
)

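# Hedged example, not part of the original script: per the per-size hints above, a 13B run would
# swap in the following values (everything else can stay the same):
#   --num_key_value_heads=40 --intermediate_size=13824 --hidden_width=5120
#   --num_layers=40 --num_heads=40
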
AUTO_RESUME=""
if [ -d "/opt/sagemaker_cluster" ]; then
    echo "Detected Hyperpod cluster... enabling --auto-resume=1"
    AUTO_RESUME="--auto-resume=1"
fi

srun ${AUTO_RESUME} -l "${ARGS[@]}" torchrun "${TORCHRUN_ARGS[@]}" $TRAIN_SCRIPT "${TRAINING_ARGS[@]}"
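
A hedged submission sketch (the script filename is an assumption; node count and image path are
examples only): because IMAGE, DATA_PATH, and FSX_MOUNT use shell default expansion above, they can
be overridden from the submitting environment.

    # Submit on 4 nodes with an explicit squashfs path.
    IMAGE=/fsx/images/transformer-engine.sqsh sbatch --nodes=4 llama_fsdp_fp8.sbatch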
