Skip to content

Commit dcb6caf

Browse files
committed
Fix GPU tests
1 parent 0958c18 commit dcb6caf

File tree

1 file changed

+63
-37
lines changed

1 file changed

+63
-37
lines changed

tests/unit/test_instance.py

Lines changed: 63 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
confidential_start,
2626
create,
2727
delete,
28+
gpu_create,
2829
list_instances,
2930
logs,
3031
reboot,
@@ -120,7 +121,7 @@ def dict_to_ci_multi_dict_proxy(d: dict) -> CIMultiDictProxy:
120121

121122

122123
def create_mock_fetch_latest_crn_version():
123-
return AsyncMock(return_value="123.420.69")
124+
return AsyncMock(return_value="0.0.0")
124125

125126

126127
@pytest.mark.asyncio
@@ -241,31 +242,37 @@ def create_mock_validate_ssh_pubkey_file():
241242
)
242243

243244

244-
def create_mock_fetch_crn_info():
245+
def mock_crn_info():
245246
mock_machine_info = dummy_machine_info()
246-
return AsyncMock(
247-
return_value=CRNInfo(
248-
hash=ItemHash(FAKE_CRN_HASH),
249-
name="Mock CRN",
250-
owner=FAKE_ADDRESS_EVM,
251-
url=FAKE_CRN_URL,
252-
ccn_hash=FAKE_CRN_HASH,
253-
status="linked",
254-
version="123.420.69",
255-
score=0.9,
256-
reward_address=FAKE_ADDRESS_EVM,
257-
stream_reward_address=mock_machine_info.reward_address,
258-
machine_usage=mock_machine_info.machine_usage,
259-
ipv6=True,
260-
qemu_support=True,
261-
confidential_computing=True,
262-
gpu_support=True,
263-
terms_and_conditions=FAKE_STORE_HASH,
264-
compatible_available_gpus=[dummy_gpu_device()],
265-
)
247+
return CRNInfo(
248+
hash=ItemHash(FAKE_CRN_HASH),
249+
name="Mock CRN",
250+
owner=FAKE_ADDRESS_EVM,
251+
url=FAKE_CRN_URL,
252+
ccn_hash=FAKE_CRN_HASH,
253+
status="linked",
254+
version="123.420.69",
255+
score=0.9,
256+
reward_address=FAKE_ADDRESS_EVM,
257+
stream_reward_address=mock_machine_info.reward_address,
258+
machine_usage=mock_machine_info.machine_usage,
259+
ipv6=True,
260+
qemu_support=True,
261+
confidential_computing=True,
262+
gpu_support=True,
263+
terms_and_conditions=FAKE_STORE_HASH,
264+
compatible_available_gpus=[gpu.dict() for gpu in mock_machine_info.machine_usage.gpu.available_devices],
266265
)
267266

268267

268+
def create_mock_fetch_crn_info():
269+
return AsyncMock(return_value=mock_crn_info())
270+
271+
272+
def create_mock_crn_table():
273+
return MagicMock(return_value=MagicMock(run_async=AsyncMock(return_value=(mock_crn_info(), 0))))
274+
275+
269276
def create_mock_fetch_vm_info():
270277
return AsyncMock(
271278
return_value=[FAKE_VM_HASH, {"crn_url": FAKE_CRN_URL, "allocation_type": help_strings.ALLOCATION_MANUAL}]
@@ -363,19 +370,6 @@ def create_mock_vm_coco_client():
363370
return mock_vm_coco_client_class, mock_vm_coco_client
364371

365372

366-
# TODO: GPU test requires a rework
367-
""" ( # gpu_superfluid_evm
368-
{
369-
"payment_type": "superfluid",
370-
"payment_chain": "BASE",
371-
"rootfs": "debian12",
372-
"crn_url": FAKE_CRN_URL,
373-
"gpu": True,
374-
},
375-
(FAKE_VM_HASH, FAKE_CRN_URL, "BASE"),
376-
), """
377-
378-
379373
@pytest.mark.parametrize(
380374
ids=[
381375
"regular_hold_evm",
@@ -384,7 +378,7 @@ def create_mock_vm_coco_client():
384378
"coco_hold_sol",
385379
"coco_hold_evm",
386380
"coco_superfluid_evm",
387-
# "gpu_superfluid_evm",
381+
"gpu_superfluid_evm",
388382
],
389383
argnames="args, expected",
390384
argvalues=[
@@ -443,6 +437,16 @@ def create_mock_vm_coco_client():
443437
},
444438
(FAKE_VM_HASH, FAKE_CRN_URL, "BASE"),
445439
),
440+
( # gpu_superfluid_evm
441+
{
442+
"payment_type": "superfluid",
443+
"payment_chain": "BASE",
444+
"rootfs": "debian12",
445+
"crn_url": FAKE_CRN_URL,
446+
"gpu": True,
447+
},
448+
(FAKE_VM_HASH, FAKE_CRN_URL, "BASE"),
449+
),
446450
],
447451
)
448452
@pytest.mark.asyncio
@@ -454,8 +458,11 @@ async def test_create_instance(args, expected):
454458
mock_client_class, mock_client = create_mock_client(payment_type=args["payment_type"])
455459
mock_auth_client_class, mock_auth_client = create_mock_auth_client(mock_account, payment_type=args["payment_type"])
456460
mock_vm_client_class, mock_vm_client = create_mock_vm_client()
461+
mock_validated_prompt = MagicMock(return_value="1")
457462
mock_fetch_latest_crn_version = create_mock_fetch_latest_crn_version()
458463
mock_fetch_crn_info = create_mock_fetch_crn_info()
464+
mock_crn_table = create_mock_crn_table()
465+
mock_yes_no_input = MagicMock(side_effect=[False, True, True])
459466
mock_wait_for_processed_instance = AsyncMock()
460467
mock_wait_for_confirmed_flow = AsyncMock()
461468

@@ -464,8 +471,11 @@ async def test_create_instance(args, expected):
464471
@patch("aleph_client.commands.instance.get_balance", mock_get_balance)
465472
@patch("aleph_client.commands.instance.AlephHttpClient", mock_client_class)
466473
@patch("aleph_client.commands.instance.AuthenticatedAlephHttpClient", mock_auth_client_class)
474+
@patch("aleph_client.commands.pricing.validated_prompt", mock_validated_prompt)
467475
@patch("aleph_client.commands.instance.network.fetch_latest_crn_version", mock_fetch_latest_crn_version)
468476
@patch("aleph_client.commands.instance.fetch_crn_info", mock_fetch_crn_info)
477+
@patch("aleph_client.commands.instance.CRNTable", mock_crn_table)
478+
@patch("aleph_client.commands.instance.yes_no_input", mock_yes_no_input)
469479
@patch("aleph_client.commands.instance.wait_for_processed_instance", mock_wait_for_processed_instance)
470480
@patch.object(asyncio, "sleep", AsyncMock())
471481
@patch("aleph_client.commands.instance.wait_for_confirmed_flow", mock_wait_for_confirmed_flow)
@@ -498,7 +508,10 @@ async def create_instance(instance_spec):
498508
# CRN related assertions
499509
if args["payment_type"] == "superfluid" or args.get("confidential") or args.get("gpu"):
500510
mock_fetch_latest_crn_version.assert_called()
501-
mock_fetch_crn_info.assert_called_once()
511+
if not args.get("gpu"):
512+
mock_fetch_crn_info.assert_called_once()
513+
else:
514+
mock_crn_table.return_value.run_async.assert_called_once()
502515
mock_wait_for_processed_instance.assert_called_once()
503516
mock_vm_client.start_instance.assert_called_once()
504517
assert returned == expected
@@ -770,3 +783,16 @@ async def coco_create(instance_spec):
770783
mock_allocate.assert_called_once()
771784
mock_confidential_init_session.assert_called_once()
772785
mock_confidential_start.assert_called_once()
786+
787+
788+
@pytest.mark.asyncio
789+
async def test_gpu_create():
790+
mock_create = AsyncMock(return_value=[FAKE_VM_HASH, FAKE_CRN_URL, "AVAX"])
791+
792+
@patch("aleph_client.commands.instance.create", mock_create)
793+
async def gpu_instance():
794+
print() # For better display when pytest -v -s
795+
await gpu_create()
796+
mock_create.assert_called_once()
797+
798+
await gpu_instance()

0 commit comments

Comments
 (0)