Skip to content

Better avoid and handle AWS API throttling #459

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 66 additions & 67 deletions dask_cloudprovider/aws/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from dask_cloudprovider.aws.helper import (
dict_to_aws,
aws_to_dict,
get_sleep_duration,
get_default_vpc,
get_vpc_subnets,
create_default_security_group,
Expand All @@ -25,6 +24,7 @@

try:
from botocore.exceptions import ClientError
from aiobotocore.config import AioConfig
from aiobotocore.session import get_session
except ImportError as e:
msg = (
Expand Down Expand Up @@ -120,6 +120,7 @@ def __init__(
fargate_use_private_ip=False,
fargate_capacity_provider=None,
task_kwargs=None,
is_task_long_arn_format_enabled=True,
**kwargs,
):
self.lock = asyncio.Lock()
Expand All @@ -144,6 +145,7 @@ def __init__(
self._fargate_capacity_provider = fargate_capacity_provider
self.kwargs = kwargs
self.task_kwargs = task_kwargs
self._is_task_long_arn_format_enabled = is_task_long_arn_format_enabled
self.status = Status.created

def __await__(self):
Expand All @@ -160,36 +162,15 @@ async def _():
def _use_public_ip(self):
return self.fargate and not self._fargate_use_private_ip

async def _is_long_arn_format_enabled(self):
async with self._client("ecs") as ecs:
[response] = (
await ecs.list_account_settings(
name="taskLongArnFormat", effectiveSettings=True
)
)["settings"]
return response["value"] == "enabled"

async def _update_task(self):
async with self._client("ecs") as ecs:
wait_duration = 1
while True:
try:
[self.task] = (
await ecs.describe_tasks(
cluster=self.cluster_arn, tasks=[self.task_arn]
)
)["tasks"]
except ClientError as e:
if e.response["Error"]["Code"] == "ThrottlingException":
wait_duration = min(wait_duration * 2, 20)
else:
raise
else:
break
await asyncio.sleep(wait_duration)
Comment on lines -175 to -189
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am I right in understanding that we are removing our retry logic here and leveraging the built-in retries in aiobotocore?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. Sorry, I stuck this draft PR up as a placeholder and hadn't yet added a detailed description to describe this. Done now.

[self.task] = (
await ecs.describe_tasks(
cluster=self.cluster_arn, tasks=[self.task_arn]
)
)["tasks"]

async def _task_is_running(self):
await self._update_task()
def _task_is_running(self):
return self.task["lastStatus"] == "RUNNING"

async def start(self):
Expand All @@ -199,7 +180,7 @@ async def start(self):
kwargs = self.task_kwargs.copy() if self.task_kwargs is not None else {}

# Tags are only supported if you opt into long arn format so we need to check for that
if await self._is_long_arn_format_enabled():
if self._is_task_long_arn_format_enabled:
kwargs["tags"] = dict_to_aws(self.tags)
if self.platform_version and self.fargate:
kwargs["platformVersion"] = self.platform_version
Expand Down Expand Up @@ -253,13 +234,19 @@ async def start(self):
[self.task] = response["tasks"]
break
except Exception as e:
# Retries due to throttle errors are handled by the aiobotocore client so this should be an uncommon case
timeout.set_exception(e)
await asyncio.sleep(1)
logger.debug(f"Failed to start {self.task_type} task after {timeout.elapsed_time:.1f}s, retrying in 1s: {e}")
await asyncio.sleep(2)

self.task_arn = self.task["taskArn"]

# Wait for the task to come up
while self.task["lastStatus"] in ["PENDING", "PROVISIONING"]:
# Try to avoid hitting throttling rate limits when bringing up a large cluster
await asyncio.sleep(1)
await self._update_task()
if not await self._task_is_running():
if not self._task_is_running():
raise RuntimeError("%s failed to start" % type(self).__name__)
[eni] = [
attachment
Expand All @@ -286,7 +273,7 @@ async def close(self, **kwargs):
async with self._client("ecs") as ecs:
await ecs.stop_task(cluster=self.cluster_arn, task=self.task_arn)
await self._update_task()
while self.task["lastStatus"] in ["RUNNING"]:
while self._task_is_running():
await asyncio.sleep(1)
await self._update_task()
self.status = Status.closed
Expand All @@ -304,48 +291,35 @@ def _log_stream_name(self):
)

async def logs(self, follow=False):
current_try = 0
next_token = None
read_from = 0

while True:
try:
async with self._client("logs") as logs:
if next_token:
l = await logs.get_log_events(
logGroupName=self.log_group,
logStreamName=self._log_stream_name,
nextToken=next_token,
)
else:
l = await logs.get_log_events(
logGroupName=self.log_group,
logStreamName=self._log_stream_name,
startTime=read_from,
)
if next_token != l["nextForwardToken"]:
next_token = l["nextForwardToken"]
async with self._client("logs") as logs:
if next_token:
l = await logs.get_log_events(
logGroupName=self.log_group,
logStreamName=self._log_stream_name,
nextToken=next_token,
)
else:
next_token = None
if not l["events"]:
if follow:
await asyncio.sleep(1)
else:
break
for event in l["events"]:
read_from = event["timestamp"]
yield event["message"]
except ClientError as e:
if e.response["Error"]["Code"] == "ThrottlingException":
warnings.warn(
"get_log_events rate limit exceeded, retrying after delay.",
RuntimeWarning,
l = await logs.get_log_events(
logGroupName=self.log_group,
logStreamName=self._log_stream_name,
startTime=read_from,
)
backoff_duration = get_sleep_duration(current_try)
await asyncio.sleep(backoff_duration)
current_try += 1
if next_token != l["nextForwardToken"]:
next_token = l["nextForwardToken"]
else:
next_token = None
if not l["events"]:
if follow:
await asyncio.sleep(1)
else:
raise
break
for event in l["events"]:
read_from = event["timestamp"]
yield event["message"]

def __repr__(self):
return "<ECS Task %s: status=%s>" % (type(self).__name__, self.status)
Expand Down Expand Up @@ -813,6 +787,7 @@ def __init__(
self._platform_version = platform_version
self._lock = asyncio.Lock()
self.session = get_session()
self._is_task_long_arn_format_enabled = None
super().__init__(**kwargs)

def _client(self, name: str):
Expand All @@ -821,6 +796,16 @@ def _client(self, name: str):
aws_access_key_id=self._aws_access_key_id,
aws_secret_access_key=self._aws_secret_access_key,
region_name=self._region_name,
config=AioConfig(
retries={
# Use Standard retry mode which provides:
# - Jittered exponential backoff with max of 20s in the event of failures
# - Never delays the first request attempt, only the retries
# - Supports circuit-breaking to prevent the SDK from retrying during outages
"mode": "standard",
"max_attempts": 10, # Not including the initial request
}
),
)

async def _start(
Expand Down Expand Up @@ -950,6 +935,10 @@ async def _start(
self.worker_task_definition_arn = (
await self._create_worker_task_definition_arn()
)
if self._is_task_long_arn_format_enabled is None:
self._is_task_long_arn_format_enabled = (
await self._get_is_task_long_arn_format_enabled()
)

options = {
"client": self._client,
Expand All @@ -962,6 +951,7 @@ async def _start(
"tags": self.tags,
"platform_version": self._platform_version,
"fargate_use_private_ip": self._fargate_use_private_ip,
"is_task_long_arn_format_enabled": self._is_task_long_arn_format_enabled,
}
scheduler_options = {
"task_definition_arn": self.scheduler_task_definition_arn,
Expand Down Expand Up @@ -1319,6 +1309,15 @@ async def _delete_worker_task_definition_arn(self):
taskDefinition=self.worker_task_definition_arn
)

async def _get_is_task_long_arn_format_enabled(self):
async with self._client("ecs") as ecs:
[response] = (
await ecs.list_account_settings(
name="taskLongArnFormat", effectiveSettings=True
)
)["settings"]
return response["value"] == "enabled"

def logs(self):
async def get_logs(task):
log = ""
Expand Down
9 changes: 8 additions & 1 deletion dask_cloudprovider/utils/timeout.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def run(self):
self.start = datetime.now()
self.running = True

if self.start + timedelta(seconds=self.timeout) < datetime.now():
if self.elapsed_time >= self.timeout:
if self.warn:
warnings.warn(self.error_message)
return False
Expand All @@ -82,3 +82,10 @@ def set_exception(self, e):
the thing you are trying rather than a TimeoutException.
"""
self.exception = e

@property
def elapsed_time(self):
"""Return the elapsed time since the timeout started."""
if self.start is None:
return 0
return (datetime.now() - self.start).total_seconds()
Loading