25
25
confidential_start ,
26
26
create ,
27
27
delete ,
28
+ gpu_create ,
28
29
list_instances ,
29
30
logs ,
30
31
reboot ,
@@ -120,7 +121,7 @@ def dict_to_ci_multi_dict_proxy(d: dict) -> CIMultiDictProxy:
120
121
121
122
122
123
def create_mock_fetch_latest_crn_version ():
123
- return AsyncMock (return_value = "123.420.69 " )
124
+ return AsyncMock (return_value = "0.0.0 " )
124
125
125
126
126
127
@pytest .mark .asyncio
@@ -241,31 +242,37 @@ def create_mock_validate_ssh_pubkey_file():
241
242
)
242
243
243
244
244
- def create_mock_fetch_crn_info ():
245
+ def mock_crn_info ():
245
246
mock_machine_info = dummy_machine_info ()
246
- return AsyncMock (
247
- return_value = CRNInfo (
248
- hash = ItemHash (FAKE_CRN_HASH ),
249
- name = "Mock CRN" ,
250
- owner = FAKE_ADDRESS_EVM ,
251
- url = FAKE_CRN_URL ,
252
- ccn_hash = FAKE_CRN_HASH ,
253
- status = "linked" ,
254
- version = "123.420.69" ,
255
- score = 0.9 ,
256
- reward_address = FAKE_ADDRESS_EVM ,
257
- stream_reward_address = mock_machine_info .reward_address ,
258
- machine_usage = mock_machine_info .machine_usage ,
259
- ipv6 = True ,
260
- qemu_support = True ,
261
- confidential_computing = True ,
262
- gpu_support = True ,
263
- terms_and_conditions = FAKE_STORE_HASH ,
264
- compatible_available_gpus = [dummy_gpu_device ()],
265
- )
247
+ return CRNInfo (
248
+ hash = ItemHash (FAKE_CRN_HASH ),
249
+ name = "Mock CRN" ,
250
+ owner = FAKE_ADDRESS_EVM ,
251
+ url = FAKE_CRN_URL ,
252
+ ccn_hash = FAKE_CRN_HASH ,
253
+ status = "linked" ,
254
+ version = "123.420.69" ,
255
+ score = 0.9 ,
256
+ reward_address = FAKE_ADDRESS_EVM ,
257
+ stream_reward_address = mock_machine_info .reward_address ,
258
+ machine_usage = mock_machine_info .machine_usage ,
259
+ ipv6 = True ,
260
+ qemu_support = True ,
261
+ confidential_computing = True ,
262
+ gpu_support = True ,
263
+ terms_and_conditions = FAKE_STORE_HASH ,
264
+ compatible_available_gpus = [gpu .dict () for gpu in mock_machine_info .machine_usage .gpu .available_devices ],
266
265
)
267
266
268
267
268
+ def create_mock_fetch_crn_info ():
269
+ return AsyncMock (return_value = mock_crn_info ())
270
+
271
+
272
+ def create_mock_crn_table ():
273
+ return MagicMock (return_value = MagicMock (run_async = AsyncMock (return_value = (mock_crn_info (), 0 ))))
274
+
275
+
269
276
def create_mock_fetch_vm_info ():
270
277
return AsyncMock (
271
278
return_value = [FAKE_VM_HASH , {"crn_url" : FAKE_CRN_URL , "allocation_type" : help_strings .ALLOCATION_MANUAL }]
@@ -363,19 +370,6 @@ def create_mock_vm_coco_client():
363
370
return mock_vm_coco_client_class , mock_vm_coco_client
364
371
365
372
366
- # TODO: GPU test requires a rework
367
- """ ( # gpu_superfluid_evm
368
- {
369
- "payment_type": "superfluid",
370
- "payment_chain": "BASE",
371
- "rootfs": "debian12",
372
- "crn_url": FAKE_CRN_URL,
373
- "gpu": True,
374
- },
375
- (FAKE_VM_HASH, FAKE_CRN_URL, "BASE"),
376
- ), """
377
-
378
-
379
373
@pytest .mark .parametrize (
380
374
ids = [
381
375
"regular_hold_evm" ,
@@ -384,7 +378,7 @@ def create_mock_vm_coco_client():
384
378
"coco_hold_sol" ,
385
379
"coco_hold_evm" ,
386
380
"coco_superfluid_evm" ,
387
- # "gpu_superfluid_evm",
381
+ "gpu_superfluid_evm" ,
388
382
],
389
383
argnames = "args, expected" ,
390
384
argvalues = [
@@ -443,6 +437,16 @@ def create_mock_vm_coco_client():
443
437
},
444
438
(FAKE_VM_HASH , FAKE_CRN_URL , "BASE" ),
445
439
),
440
+ ( # gpu_superfluid_evm
441
+ {
442
+ "payment_type" : "superfluid" ,
443
+ "payment_chain" : "BASE" ,
444
+ "rootfs" : "debian12" ,
445
+ "crn_url" : FAKE_CRN_URL ,
446
+ "gpu" : True ,
447
+ },
448
+ (FAKE_VM_HASH , FAKE_CRN_URL , "BASE" ),
449
+ ),
446
450
],
447
451
)
448
452
@pytest .mark .asyncio
@@ -454,8 +458,11 @@ async def test_create_instance(args, expected):
454
458
mock_client_class , mock_client = create_mock_client (payment_type = args ["payment_type" ])
455
459
mock_auth_client_class , mock_auth_client = create_mock_auth_client (mock_account , payment_type = args ["payment_type" ])
456
460
mock_vm_client_class , mock_vm_client = create_mock_vm_client ()
461
+ mock_validated_prompt = MagicMock (return_value = "1" )
457
462
mock_fetch_latest_crn_version = create_mock_fetch_latest_crn_version ()
458
463
mock_fetch_crn_info = create_mock_fetch_crn_info ()
464
+ mock_crn_table = create_mock_crn_table ()
465
+ mock_yes_no_input = MagicMock (side_effect = [False , True , True ])
459
466
mock_wait_for_processed_instance = AsyncMock ()
460
467
mock_wait_for_confirmed_flow = AsyncMock ()
461
468
@@ -464,8 +471,11 @@ async def test_create_instance(args, expected):
464
471
@patch ("aleph_client.commands.instance.get_balance" , mock_get_balance )
465
472
@patch ("aleph_client.commands.instance.AlephHttpClient" , mock_client_class )
466
473
@patch ("aleph_client.commands.instance.AuthenticatedAlephHttpClient" , mock_auth_client_class )
474
+ @patch ("aleph_client.commands.pricing.validated_prompt" , mock_validated_prompt )
467
475
@patch ("aleph_client.commands.instance.network.fetch_latest_crn_version" , mock_fetch_latest_crn_version )
468
476
@patch ("aleph_client.commands.instance.fetch_crn_info" , mock_fetch_crn_info )
477
+ @patch ("aleph_client.commands.instance.CRNTable" , mock_crn_table )
478
+ @patch ("aleph_client.commands.instance.yes_no_input" , mock_yes_no_input )
469
479
@patch ("aleph_client.commands.instance.wait_for_processed_instance" , mock_wait_for_processed_instance )
470
480
@patch .object (asyncio , "sleep" , AsyncMock ())
471
481
@patch ("aleph_client.commands.instance.wait_for_confirmed_flow" , mock_wait_for_confirmed_flow )
@@ -498,7 +508,10 @@ async def create_instance(instance_spec):
498
508
# CRN related assertions
499
509
if args ["payment_type" ] == "superfluid" or args .get ("confidential" ) or args .get ("gpu" ):
500
510
mock_fetch_latest_crn_version .assert_called ()
501
- mock_fetch_crn_info .assert_called_once ()
511
+ if not args .get ("gpu" ):
512
+ mock_fetch_crn_info .assert_called_once ()
513
+ else :
514
+ mock_crn_table .return_value .run_async .assert_called_once ()
502
515
mock_wait_for_processed_instance .assert_called_once ()
503
516
mock_vm_client .start_instance .assert_called_once ()
504
517
assert returned == expected
@@ -770,3 +783,16 @@ async def coco_create(instance_spec):
770
783
mock_allocate .assert_called_once ()
771
784
mock_confidential_init_session .assert_called_once ()
772
785
mock_confidential_start .assert_called_once ()
786
+
787
+
788
+ @pytest .mark .asyncio
789
+ async def test_gpu_create ():
790
+ mock_create = AsyncMock (return_value = [FAKE_VM_HASH , FAKE_CRN_URL , "AVAX" ])
791
+
792
+ @patch ("aleph_client.commands.instance.create" , mock_create )
793
+ async def gpu_instance ():
794
+ print () # For better display when pytest -v -s
795
+ await gpu_create ()
796
+ mock_create .assert_called_once ()
797
+
798
+ await gpu_instance ()
0 commit comments