
Adding LWS Integration #1174

Open · wants to merge 23 commits into main
200 changes: 198 additions & 2 deletions axlearn/cloud/gcp/job.py
@@ -19,8 +19,15 @@
from axlearn.cloud.common.utils import generate_job_name, subprocess_run
from axlearn.cloud.gcp.config import default_env_id, default_project, default_zone
from axlearn.cloud.gcp.jobset_utils import BaseReplicatedJob
from axlearn.cloud.gcp.utils import custom_jobset_kwargs, delete_k8s_jobset
from axlearn.common.config import REQUIRED, ConfigOr, Required, config_class, maybe_instantiate
from axlearn.cloud.gcp.lws_utils import BaseLeaderWorkerTemplate
from axlearn.cloud.gcp.utils import (
custom_jobset_kwargs,
custom_leaderworkerset_kwargs,
delete_k8s_jobset,
delete_k8s_leaderworkerset,
delete_k8s_service,
)
from axlearn.common.config import REQUIRED, ConfigBase, ConfigOr, Required, config_class, maybe_instantiate
from axlearn.common.utils import Nested


@@ -267,3 +274,192 @@ def docker_command(
)
logging.debug("Docker run command: %s", cmd)
return cmd


class GKELeaderWorkerSet(GCPJob):
"""Base GKE LeaderWorkerSet interface"""

@config_class
class Config(GCPJob.Config):
"""Configures GKELeaderWorkerSet.
Attributes:
builder: A builder that returns one or more statefulset specs.
namespace: The namespace to use within the k8s cluster.
annotations: LeaderWorkerSet annotations.
"""

builder: Required[BaseLeaderWorkerTemplate.Config] = REQUIRED
namespace: str = "default"
annotations: Optional[ConfigOr[dict]] = None
num_replicas: int = 1

@classmethod
def set_defaults(cls, fv):
super().set_defaults(fv)
fv.set_default("max_tries", fv.max_tries or 10)
fv.set_default("retry_interval", fv.retry_interval or 60)

@classmethod
def define_flags(cls, fv: flags.FlagValues):
super().define_flags(fv)
common_kwargs = dict(flag_values=fv, allow_override=True)
flags.DEFINE_string("name", None, "Name of the LeaderWorkerSet.", **common_kwargs)

@classmethod
def from_flags(cls, fv: flags.FlagValues, **kwargs):
cfg: GKELeaderWorkerSet.Config = super().from_flags(fv, **kwargs)
cfg.num_replicas = fv.num_replicas
return cfg

def __init__(self, cfg: Config, *, bundler: BaseDockerBundler):
super().__init__(cfg)
cfg: GKELeaderWorkerSet.Config = self.config
self._bundler = bundler
# This instantiates a builder for constructing the leader/worker template spec, which is
# managed under the LeaderWorkerSet represented by this class.
# Note the distinction from bundlers, which are responsible for bundling any code assets
# required to run the job.
self._builder: BaseLeaderWorkerTemplate = cfg.builder.instantiate(bundler=bundler)

def _delete(self):
cfg: GKELeaderWorkerSet.Config = self.config
# Issues a delete request for the LeaderWorkerSet and proactively deletes its descendants.
# This is not fully blocking; after the call returns there can be a delay before
# everything is deleted.
delete_k8s_leaderworkerset(cfg.name, namespace=cfg.namespace)
delete_k8s_service(cfg.name + "-service", namespace=cfg.namespace)

def _build_leaderworkerset(self) -> Nested[Any]:
"""
Builds a config for a LeaderWorkerSet, which is a set for multi-host inference

Returns:
A nested dict corresponding to a k8s LWS config
"""
cfg: GKELeaderWorkerSet.Config = self.config
annotations = maybe_instantiate(cfg.annotations or {})

return dict(
metadata=dict(name=cfg.name, annotations=annotations),
spec=dict(
replicas=cfg.num_replicas,
leaderWorkerTemplate=self._builder(),
),
)
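# For reference, a rough sketch of the nested dict this returns (field names follow the
# LeaderWorkerSet API; the leaderWorkerTemplate contents come from the builder and are
# illustrative only):
#
#   {
#       "metadata": {"name": <name>, "annotations": {...}},
#       "spec": {
#           "replicas": <num_replicas>,
#           "leaderWorkerTemplate": {"size": ..., "leaderTemplate": {...}, "workerTemplate": {...}},
#       },
#   }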

def _execute(self):
cfg: GKELeaderWorkerSet.Config = self.config

# Create a Service fronting the LeaderWorkerSet pods.
service = Service(cfg)
resp = service.execute(cfg)
logging.info("Service created %s", str(resp))

api_kwargs = custom_leaderworkerset_kwargs()
custom_object = dict(
apiVersion=f"{api_kwargs['group']}/{api_kwargs['version']}",
kind="LeaderWorkerSet",
**self._build_leaderworkerset(),
)
logging.info("submitting LeaderWorkerSet: %s", custom_object)
return k8s.client.CustomObjectsApi().create_namespaced_custom_object(
namespace=cfg.namespace,
body=custom_object,
**api_kwargs,
)


def exclusive_topology_annotations_leaderworkerset() -> dict:
"""Used for TPU GKELeaderWorkerSet.

The exclusive-topology annotation adds affinity rules to all Pods so that each replica
is scheduled entirely within the same pod-slice node-pool.
"""
return {"leaderworkerset.sigs.k8s.io/exclusive-topology": "cloud.google.com/gke-nodepool"}


class Service:
"""Service interface."""

@config_class
class Config(ConfigBase):
"""Configures Service.

Attributes:
name: Name of the Service (a "-service" suffix is appended at build time).
"""

name: Required[str] = REQUIRED

@classmethod
def define_flags(cls, fv: flags.FlagValues):
common_kwargs = dict(flag_values=fv, allow_override=True)
flags.DEFINE_string("name", None, "Name of the service.", **common_kwargs)

@classmethod
def set_defaults(cls, fv: flags.FlagValues):
super().set_defaults(fv)
fv.set_default("name", fv.name or generate_job_name())


def __init__(self, cfg: Config):
# Store the config so that _delete() can reference it later.
self.config = cfg

def _delete(self):
cfg: Service.Config = self.config
# Issues a delete request for the Service created by execute() (which appends the
# "-service" suffix). This is not fully blocking; after the call returns there can be
# a delay before the Service is fully removed.
delete_k8s_service(cfg.name + "-service", namespace="default")

def _build_service(self, cfg: Config) -> Nested[Any]:
"""Builds a config for a Service.

Returns:
A nested dict corresponding to a k8s Service config.
"""
return dict(
metadata=k8s.client.V1ObjectMeta(name=cfg.name + "-service"),
spec=k8s.client.V1ServiceSpec(
selector={"app": cfg.name},
ports=[
k8s.client.V1ServicePort(
protocol="TCP",
port=8000,
target_port=8000,
)
],
type="ClusterIP",  # Or "NodePort" / "LoadBalancer".
),
)
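# For reference, the dict above corresponds roughly to the following manifest (port and
# type are the values hardcoded here):
#
#   apiVersion: v1
#   kind: Service
#   metadata: {name: <name>-service}
#   spec:
#     selector: {app: <name>}
#     ports: [{protocol: TCP, port: 8000, targetPort: 8000}]
#     type: ClusterIP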

def execute(self, cfg: Config):
"""Builds the Service spec and submits it to the cluster."""
service = self._build_service(cfg)
logging.info("Submitting service body=%s", service)
v1 = k8s.client.CoreV1Api()
return v1.create_namespaced_service(namespace="default", body=service)


87 changes: 86 additions & 1 deletion axlearn/cloud/gcp/job_test.py
@@ -11,7 +11,7 @@

from axlearn.cloud.common.bundler import Bundler
from axlearn.cloud.common.utils import define_flags, from_flags
from axlearn.cloud.gcp import bundler, job, jobset_utils
from axlearn.cloud.gcp import bundler, job, jobset_utils, lws_utils
from axlearn.cloud.gcp.bundler import ArtifactRegistryBundler, CloudBuildBundler
from axlearn.cloud.gcp.test_utils import default_mock_settings, mock_gcp_settings
from axlearn.common.config import REQUIRED, Required, config_class
@@ -211,3 +211,88 @@ def test_build_jobset(
self.assertNotIn("kueue.x-k8s.io/queue-name", jobset_annotations)
else:
self.assertEqual(jobset_annotations["kueue.x-k8s.io/queue-name"], queue)


class TPUGKELeaderWorkerSetTest(TestCase):
"""Tests GKELeaderWorkerSet with TPU."""

def run(self, result=None):
# Run tests under mock user and settings.
self._settings = default_mock_settings()
with mock_gcp_settings(
[lws_utils.__name__, bundler.__name__],
settings=self._settings,
):
return super().run(result)

def _job_config(
self,
*,
command: str,
bundler_cls: type[Bundler],
**kwargs,
) -> tuple[job.GKELeaderWorkerSet.Config, Bundler.Config]:
fv = flags.FlagValues()
cfg = job.GKELeaderWorkerSet.default_config().set(
builder=lws_utils.TPULeaderWorkerTemplate.default_config()
)
define_flags(cfg, fv)
for key, value in kwargs.items():
if value is not None:
# Use setattr rather than set_default to set flags.
setattr(fv, key, value)
fv.name = "fake-name"
fv.output_dir = "FAKE"
fv.instance_type = "tpu-v4-8"
fv.mark_as_parsed()
from_flags(cfg, fv, command=command)
# Test that retries are configured on fv by default.
self.assertIsNotNone(fv["max_tries"].default)
self.assertIsNotNone(fv["retry_interval"].default)
bundler_cfg = bundler_cls.from_spec([], fv=fv).set(image="test-image")
return cfg, bundler_cfg

@parameterized.product(
reservation=[None, "test"],
bundler_cls=[ArtifactRegistryBundler, CloudBuildBundler],
wrap_bundler=[False, True],
)
def test_instantiate(
self,
reservation,
bundler_cls: type[Bundler],
wrap_bundler,
):
class WrappedBundler(Bundler):
@config_class
class Config(Bundler.Config):
inner: Required[Bundler.Config] = REQUIRED

cfg, bundler_cfg = self._job_config(
command="test-command",
bundler_cls=bundler_cls,
reservation=reservation,
num_replicas=1,
)

self.assertIsInstance(cfg.builder, lws_utils.TPULeaderWorkerTemplate.Config)
cfg.builder = cast(lws_utils.TPULeaderWorkerTemplate.Config, cfg.builder)

self.assertEqual(cfg.name, cfg.builder.name)
self.assertEqual(cfg.project, self._settings["project"])
self.assertEqual(cfg.zone, self._settings["zone"])
self.assertEqual(cfg.builder.reservation, reservation or self._settings["gke_reservation"])
self.assertEqual(cfg.num_replicas, 1)
# Should work with wrapped bundlers.
if wrap_bundler:
bundler_cfg = WrappedBundler.default_config().set(inner=bundler_cfg)
gke_job = cfg.instantiate(bundler=bundler_cfg.instantiate())
self.assertEqual("v4-8", gke_job._builder._tpu_type)

def test_delete(self):
patch_delete_lws = mock.patch(f"{job.__name__}.delete_k8s_leaderworkerset")
patch_delete_service = mock.patch(f"{job.__name__}.delete_k8s_service")
with patch_delete_lws as mock_delete_lws, patch_delete_service as mock_delete_service:
cfg, _ = self._job_config(command="test-command", bundler_cls=CloudBuildBundler)
gke_job = cfg.instantiate(bundler=mock.Mock())
gke_job._delete()  # pylint: disable=protected-access
mock_delete_lws.assert_called()
mock_delete_service.assert_called()