Skip to content

Commit effa2ea

Browse files
authored
Improve GPU selection (#342)
- Better CRN x GPU selection (all at once)
- Rework on GPU tests
- Minor fixes & improvements
1 parent 41575ac commit effa2ea

File tree

3 files changed

+227
-288
lines changed

3 files changed

+227
-288
lines changed

src/aleph_client/commands/instance/__init__.py

Lines changed: 124 additions & 130 deletions
Original file line number | Diff line number | Diff line change
@@ -73,7 +73,6 @@
7373
setup_logging,
7474
str_to_datetime,
7575
validate_ssh_pubkey_file,
76-
validated_int_prompt,
7776
validated_prompt,
7877
wait_for_confirmed_flow,
7978
wait_for_processed_instance,
@@ -361,7 +360,7 @@ async def create(
361360
raise typer.Exit(code=1) from e
362361

363362
stream_reward_address = None
364-
crn = None
363+
crn, gpu_id = None, None
365364
if is_stream or confidential or gpu:
366365
if crn_url:
367366
try:
@@ -399,10 +398,11 @@ async def create(
399398
only_gpu=gpu,
400399
only_gpu_model=gpu_model,
401400
)
402-
crn = await crn_table.run_async()
403-
if not crn:
401+
selection = await crn_table.run_async()
402+
if not selection:
404403
# User has ctrl-c
405404
raise typer.Exit(1)
405+
crn, gpu_id = selection
406406
crn.display_crn_specs()
407407
if not yes_no_input("Deploy on this node?", default=True):
408408
crn = None
@@ -431,28 +431,14 @@ async def create(
431431
if not safe_getattr(crn, "gpu_support"):
432432
echo("Selected CRN does not support GPU computing.")
433433
raise typer.Exit(1)
434-
if crn.machine_usage and crn.machine_usage.gpu:
435-
if len(crn.machine_usage.gpu.available_devices) < 1:
436-
echo("Selected CRN does not have any GPUs available.")
437-
raise typer.Exit(1)
438-
439-
table = Table(box=box.ROUNDED)
440-
table.add_column("Id", style="white", overflow="fold")
441-
table.add_column("Vendor", style="blue")
442-
table.add_column("Model GPU", style="magenta")
443-
available_gpus = crn.machine_usage.gpu.available_devices
444-
for index, available_gpu in enumerate(available_gpus):
445-
table.add_row(str(index + 1), available_gpu.vendor, available_gpu.device_name)
446-
table.add_section()
447-
console.print(table)
448-
449-
selected_gpu_number = (
450-
validated_int_prompt("GPU Id to use", min_value=1, max_value=len(available_gpus)) - 1
451-
)
452-
selected_gpu = available_gpus[selected_gpu_number]
434+
if not crn.compatible_available_gpus:
435+
echo("Selected CRN does not have any GPU available.")
436+
raise typer.Exit(1)
437+
else:
438+
selected_gpu = crn.compatible_available_gpus[gpu_id]
453439
gpu_selection = Text.from_markup(
454-
f"[orange3]Vendor[/orange3]: {selected_gpu.vendor}\n[orange3]Model[/orange3]: "
455-
f"{selected_gpu.device_name}"
440+
f"[orange3]Vendor[/orange3]: {selected_gpu['vendor']}\n[orange3]Model[/orange3]: "
441+
f"{selected_gpu['model']}\n[orange3]Device[/orange3]: {selected_gpu['device_name']}"
456442
)
457443
console.print(
458444
Panel(
@@ -465,12 +451,15 @@ async def create(
465451
)
466452
gpu_requirement = [
467453
GpuProperties(
468-
vendor=selected_gpu.vendor,
469-
device_name=selected_gpu.device_name,
470-
device_class=selected_gpu.device_class,
471-
device_id=selected_gpu.device_id,
454+
vendor=selected_gpu["vendor"],
455+
device_name=selected_gpu["device_name"],
456+
device_class=selected_gpu["device_class"],
457+
device_id=selected_gpu["device_id"],
472458
)
473459
]
460+
if not yes_no_input("Confirm this GPU device?", default=True):
461+
echo("GPU device selection cancelled.")
462+
raise typer.Exit(1)
474463
if crn.terms_and_conditions:
475464
tac_accepted = await crn.display_terms_and_conditions(auto_accept=crn_auto_tac)
476465
if tac_accepted is None:
@@ -807,38 +796,42 @@ async def _show_instances(messages: builtins.list[InstanceMessage]):
807796
await fetch_crn_list() # Precache CRN list
808797
scheduler_responses = dict(await asyncio.gather(*[fetch_vm_info(message) for message in messages]))
809798
uninitialized_confidential_found = False
810-
for message in messages:
811-
info = scheduler_responses[message.item_hash]
812-
if info["confidential"] and info["ipv6_logs"] == help_strings.VM_NOT_READY:
813-
uninitialized_confidential_found = True
814-
name = Text(
815-
(
816-
message.content.metadata["name"]
817-
if hasattr(message.content, "metadata")
818-
and isinstance(message.content.metadata, dict)
819-
and "name" in message.content.metadata
820-
else "-"
821-
),
822-
style="magenta3",
823-
)
824-
link = f"https://explorer.aleph.im/address/ETH/{message.sender}/message/INSTANCE/{message.item_hash}"
825-
# link = f"{settings.API_HOST}/api/v0/messages/{message.item_hash}"
826-
item_hash_link = Text.from_markup(f"[link={link}]{message.item_hash}[/link]", style="bright_cyan")
827-
payment = Text.assemble(
828-
"Payment: ",
829-
Text(
830-
info["payment"].capitalize().ljust(12),
831-
style="red" if info["payment"] == PaymentType.hold.value else "orange3",
832-
),
833-
)
834-
confidential = Text.assemble(
835-
"Type: ", Text("Confidential", style="green") if info["confidential"] else Text("Regular", style="grey50")
836-
)
837-
chain = Text.assemble("Chain: ", Text(info["chain"].ljust(14), style="white"))
838-
created_at = Text.assemble(
839-
"Created at: ", Text(str(str_to_datetime(info["created_at"])).split(".", maxsplit=1)[0], style="orchid")
840-
)
841-
async with AlephHttpClient(api_server=settings.API_HOST) as client:
799+
async with AlephHttpClient(api_server=settings.API_HOST) as client:
800+
for message in messages:
801+
info = scheduler_responses[message.item_hash]
802+
if info["confidential"] and info["ipv6_logs"] == help_strings.VM_NOT_READY:
803+
uninitialized_confidential_found = True
804+
name = Text(
805+
(
806+
message.content.metadata["name"]
807+
if hasattr(message.content, "metadata")
808+
and isinstance(message.content.metadata, dict)
809+
and "name" in message.content.metadata
810+
else "-"
811+
),
812+
style="magenta3",
813+
)
814+
link = f"https://explorer.aleph.im/address/ETH/{message.sender}/message/INSTANCE/{message.item_hash}"
815+
# link = f"{settings.API_HOST}/api/v0/messages/{message.item_hash}"
816+
item_hash_link = Text.from_markup(f"[link={link}]{message.item_hash}[/link]", style="bright_cyan")
817+
payment = Text.assemble(
818+
"Payment: ",
819+
Text(
820+
info["payment"].capitalize().ljust(12),
821+
style="red" if info["payment"] == PaymentType.hold.value else "orange3",
822+
),
823+
)
824+
confidential = Text.assemble(
825+
"Type: ",
826+
Text("Confidential", style="green") if info["confidential"] else Text("Regular", style="grey50"),
827+
)
828+
chain = Text.assemble("Chain: ", Text(info["chain"].ljust(14), style="white"))
829+
created_at = Text.assemble(
830+
"Created at: ", Text(str(str_to_datetime(info["created_at"])).split(".", maxsplit=1)[0], style="orchid")
831+
)
832+
payer: Union[str, Text] = ""
833+
if message.sender != message.content.address:
834+
payer = Text.assemble("\nPayer: ", Text(str(message.content.address), style="orange1"))
842835
price: PriceResponse = await client.get_program_price(message.item_hash)
843836
required_tokens = Decimal(price.required_tokens)
844837
if price.payment_type == PaymentType.hold.value:
@@ -849,80 +842,81 @@ async def _show_instances(messages: builtins.list[InstanceMessage]):
849842
pday = f"{displayable_amount(86400*required_tokens, decimals=3)}/day"
850843
pmonth = f"{displayable_amount(2628000*required_tokens, decimals=3)}/month"
851844
aleph_price = Text.assemble(psec, " | ", phour, " | ", pday, " | ", pmonth, style="violet")
852-
cost = Text.assemble("\n$ALEPH: ", aleph_price)
853-
payer: Union[str, Text] = ""
854-
if message.sender != message.content.address:
855-
payer = Text.assemble("\nPayer: ", Text(str(message.sender), style="orange1"))
856-
instance = Text.assemble(
857-
"Item Hash ↓\t Name: ",
858-
name,
859-
"\n",
860-
item_hash_link,
861-
"\n",
862-
payment,
863-
confidential,
864-
"\n",
865-
chain,
866-
created_at,
867-
cost,
868-
payer,
869-
)
870-
hypervisor = safe_getattr(message, "content.environment.hypervisor")
871-
specs = [
872-
f"vCPU: [magenta3]{message.content.resources.vcpus}[/magenta3]\n",
873-
f"RAM: [magenta3]{message.content.resources.memory / 1_024:.2f} GiB[/magenta3]\n",
874-
f"Disk: [magenta3]{message.content.rootfs.size_mib / 1_024:.2f} GiB[/magenta3]\n",
875-
f"HyperV: [magenta3]{hypervisor.capitalize() if hypervisor else 'Firecracker'}[/magenta3]",
876-
]
877-
gpus = safe_getattr(message, "content.requirements.gpu")
878-
if gpus:
879-
specs += [f"\n[bright_yellow]GPU [[green]{len(gpus)}[/green]]:\n"]
880-
for gpu in gpus:
881-
specs += [f"• [green]{gpu.vendor}, {gpu.device_name}[green]"]
882-
specs += ["[/bright_yellow]"]
883-
specifications = Text.from_markup("".join(specs))
884-
status_column = Text.assemble(
885-
Text.assemble(
886-
Text("Allocation: ", style="blue"),
887-
Text(
888-
info["allocation_type"] + "\n",
889-
style="magenta3" if info["allocation_type"] == help_strings.ALLOCATION_MANUAL else "deep_sky_blue1",
890-
),
891-
),
892-
(
845+
cost = Text.assemble("\n$ALEPH: ", aleph_price)
846+
instance = Text.assemble(
847+
"Item Hash ↓\t Name: ",
848+
name,
849+
"\n",
850+
item_hash_link,
851+
"\n",
852+
payment,
853+
confidential,
854+
"\n",
855+
chain,
856+
created_at,
857+
payer,
858+
cost,
859+
)
860+
hypervisor = safe_getattr(message, "content.environment.hypervisor")
861+
specs = [
862+
f"vCPU: [magenta3]{message.content.resources.vcpus}[/magenta3]\n",
863+
f"RAM: [magenta3]{message.content.resources.memory / 1_024:.2f} GiB[/magenta3]\n",
864+
f"Disk: [magenta3]{message.content.rootfs.size_mib / 1_024:.2f} GiB[/magenta3]\n",
865+
f"HyperV: [magenta3]{hypervisor.capitalize() if hypervisor else 'Firecracker'}[/magenta3]",
866+
]
867+
gpus = safe_getattr(message, "content.requirements.gpu")
868+
if gpus:
869+
specs += [f"\n[bright_yellow]GPU [[green]{len(gpus)}[/green]]:\n"]
870+
for gpu in gpus:
871+
specs += [f"• [green]{gpu.vendor}, {gpu.device_name}[green]"]
872+
specs += ["[/bright_yellow]"]
873+
specifications = Text.from_markup("".join(specs))
874+
status_column = Text.assemble(
893875
Text.assemble(
894-
Text("CRN Hash: ", style="blue"),
895-
Text(info["crn_hash"] + "\n", style=("bright_cyan")),
896-
)
897-
if info["crn_hash"]
898-
else ""
899-
),
900-
Text.assemble(
901-
Text("CRN Url: ", style="blue"),
902-
Text(
903-
info["crn_url"] + "\n",
904-
style="green1" if info["crn_url"].startswith("http") else "grey50",
876+
Text("Allocation: ", style="blue"),
877+
Text(
878+
info["allocation_type"] + "\n",
879+
style=(
880+
"magenta3"
881+
if info["allocation_type"] == help_strings.ALLOCATION_MANUAL
882+
else "deep_sky_blue1"
883+
),
884+
),
885+
),
886+
(
887+
Text.assemble(
888+
Text("CRN Hash: ", style="blue"),
889+
Text(info["crn_hash"] + "\n", style=("bright_cyan")),
890+
)
891+
if info["crn_hash"]
892+
else ""
905893
),
906-
),
907-
Text.assemble(
908-
Text("IPv6: ", style="blue"),
909-
Text(info["ipv6_logs"]),
910-
style="bright_yellow" if len(info["ipv6_logs"].split(":")) == 8 else "dark_orange",
911-
),
912-
(
913894
Text.assemble(
914-
Text(f"\n[{'✅' if info['tac_accepted'] else '❌'}] Accepted Terms & Conditions: "),
895+
Text("CRN Url: ", style="blue"),
915896
Text(
916-
f"{info['tac_url']}",
917-
style="orange1",
897+
info["crn_url"] + "\n",
898+
style="green1" if info["crn_url"].startswith("http") else "grey50",
918899
),
919-
)
920-
if info["tac_hash"]
921-
else ""
922-
),
923-
)
924-
table.add_row(instance, specifications, status_column)
925-
table.add_section()
900+
),
901+
Text.assemble(
902+
Text("IPv6: ", style="blue"),
903+
Text(info["ipv6_logs"]),
904+
style="bright_yellow" if len(info["ipv6_logs"].split(":")) == 8 else "dark_orange",
905+
),
906+
(
907+
Text.assemble(
908+
Text(f"\n[{'✅' if info['tac_accepted'] else '❌'}] Accepted Terms & Conditions: "),
909+
Text(
910+
f"{info['tac_url']}",
911+
style="orange1",
912+
),
913+
)
914+
if info["tac_hash"]
915+
else ""
916+
),
917+
)
918+
table.add_row(instance, specifications, status_column)
919+
table.add_section()
926920

927921
console = Console()
928922
console.print(table)

0 commit comments

Comments
 (0)