diff --git a/chia/_tests/cmds/cmd_test_utils.py b/chia/_tests/cmds/cmd_test_utils.py index 8666076e9304..8234fb838f03 100644 --- a/chia/_tests/cmds/cmd_test_utils.py +++ b/chia/_tests/cmds/cmd_test_utils.py @@ -7,7 +7,7 @@ from pathlib import Path from typing import Any, Optional, cast -from chia_rs import BlockRecord, Coin, G2Element +from chia_rs import BlockRecord, Coin, G1Element, G2Element from chia_rs.sized_bytes import bytes32 from chia_rs.sized_ints import uint8, uint16, uint32, uint64 @@ -44,6 +44,10 @@ NFTGetInfo, NFTGetInfoResponse, SendTransactionMultiResponse, + SignMessageByAddress, + SignMessageByAddressResponse, + SignMessageByID, + SignMessageByIDResponse, WalletInfoResponse, ) from chia.wallet.wallet_rpc_client import WalletRpcClient @@ -147,19 +151,37 @@ async def get_cat_name(self, wallet_id: int) -> str: self.add_to_log("get_cat_name", (wallet_id,)) return "test" + str(wallet_id) - async def sign_message_by_address(self, address: str, message: str) -> tuple[str, str, str]: - self.add_to_log("sign_message_by_address", (address, message)) - pubkey = bytes([3] * 48).hex() - signature = bytes([6] * 576).hex() + async def sign_message_by_address(self, request: SignMessageByAddress) -> SignMessageByAddressResponse: + self.add_to_log("sign_message_by_address", (request.address, request.message)) + pubkey = G1Element.from_bytes( + bytes.fromhex( + "b5acf3599bc5fa5da1c00f6cc3d5bcf1560def67778b7f50a8c373a83f78761505b6250ab776e38a292e26628009aec4" + ) + ) + signature = G2Element.from_bytes( + bytes.fromhex( + "c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + ) + ) signing_mode = SigningMode.CHIP_0002.value - return pubkey, signature, signing_mode + return SignMessageByAddressResponse(pubkey, signature, signing_mode) - async def sign_message_by_id(self, id: str, message: str) -> tuple[str, str, str]: - self.add_to_log("sign_message_by_id", (id, message)) - pubkey = bytes([4] * 48).hex() - signature = bytes([7] * 576).hex() + async def sign_message_by_id(self, request: SignMessageByID) -> SignMessageByIDResponse: + self.add_to_log("sign_message_by_id", (request.id, request.message)) + pubkey = G1Element.from_bytes( + bytes.fromhex( + "a9e652cb551d5978a9ee4b7aa52a4e826078a54b08a3d903c38611cb8a804a9a29c926e4f8549314a079e04ecde10cc1" + ) + ) + signature = G2Element.from_bytes( + bytes.fromhex( + "c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + ) + ) signing_mode = SigningMode.CHIP_0002.value - return pubkey, signature, signing_mode + return SignMessageByIDResponse(pubkey, signature, bytes32.zeros, signing_mode) async def cat_asset_id_to_name(self, asset_id: bytes32) -> Optional[tuple[Optional[uint32], str]]: """ @@ -250,14 +272,6 @@ async def get_spendable_coins( unconfirmed_additions = [Coin(bytes32([7] * 32), bytes32([8] * 32), uint64(1234580000))] return confirmed_records, unconfirmed_removals, unconfirmed_additions - async def get_next_address(self, wallet_id: int, new_address: bool) -> str: - self.add_to_log("get_next_address", (wallet_id, new_address)) - addr = encode_puzzle_hash(bytes32([self.wallet_index] * 32), "xch") - self.wallet_index += 1 - if self.wallet_index > 254: - self.wallet_index = 1 - return addr - async def send_transaction_multi( self, wallet_id: int, diff --git a/chia/_tests/cmds/wallet/test_did.py b/chia/_tests/cmds/wallet/test_did.py index d017408d381d..5144cde7f71b 100644 --- a/chia/_tests/cmds/wallet/test_did.py +++ b/chia/_tests/cmds/wallet/test_did.py @@ -116,8 +116,8 @@ def test_did_sign_message(capsys: object, get_test_cli_clients: tuple[TestRpcCli # these are various things that should be in the output assert_list = [ f"Message: {message.hex()}", - f"Public Key: {bytes([4] * 48).hex()}", - f"Signature: {bytes([7] * 576).hex()}", + "Public Key: a9e652cb551d5978a9ee4b7aa52a4e826078a54b08a3d903c38611cb8a804a9a29c926e4f8549314a079e04ecde10cc1", + "Signature: c0" + "00" * (42 - 1), f"Signing Mode: {SigningMode.CHIP_0002.value}", ] run_cli_command_and_assert(capsys, root_dir, [*command_args, f"-i{did_id}"], assert_list) diff --git a/chia/_tests/cmds/wallet/test_nft.py b/chia/_tests/cmds/wallet/test_nft.py index 43f10363f273..1ff5fe62834e 100644 --- a/chia/_tests/cmds/wallet/test_nft.py +++ b/chia/_tests/cmds/wallet/test_nft.py @@ -64,19 +64,19 @@ def test_nft_sign_message(capsys: object, get_test_cli_clients: tuple[TestRpcCli inst_rpc_client = TestWalletRpcClient() test_rpc_clients.wallet_rpc_client = inst_rpc_client - did_id = encode_puzzle_hash(get_bytes32(1), "nft") + nft_id = encode_puzzle_hash(get_bytes32(1), "nft") message = b"hello nft world!!" - command_args = ["wallet", "did", "sign_message", FINGERPRINT_ARG, f"-m{message.hex()}"] + command_args = ["wallet", "nft", "sign_message", FINGERPRINT_ARG, f"-m{message.hex()}"] # these are various things that should be in the output assert_list = [ f"Message: {message.hex()}", - f"Public Key: {bytes([4] * 48).hex()}", - f"Signature: {bytes([7] * 576).hex()}", + "Public Key: a9e652cb551d5978a9ee4b7aa52a4e826078a54b08a3d903c38611cb8a804a9a29c926e4f8549314a079e04ecde10cc1", + "Signature: c0" + "00" * (42 - 1), f"Signing Mode: {SigningMode.CHIP_0002.value}", ] - run_cli_command_and_assert(capsys, root_dir, [*command_args, f"-i{did_id}"], assert_list) + run_cli_command_and_assert(capsys, root_dir, [*command_args, f"-i{nft_id}"], assert_list) expected_calls: logType = { - "sign_message_by_id": [(did_id, message.hex())], # xch std + "sign_message_by_id": [(nft_id, message.hex())], # xch std } test_rpc_clients.wallet_rpc_client.check_log(expected_calls) diff --git a/chia/_tests/cmds/wallet/test_notifications.py b/chia/_tests/cmds/wallet/test_notifications.py index 40bf0afde1f4..cde0ca62b0c1 100644 --- a/chia/_tests/cmds/wallet/test_notifications.py +++ b/chia/_tests/cmds/wallet/test_notifications.py @@ -1,7 +1,7 @@ from __future__ import annotations from pathlib import Path -from typing import Optional, cast +from typing import cast from chia_rs.sized_bytes import bytes32 from chia_rs.sized_ints import uint32, uint64 @@ -12,7 +12,7 @@ from chia.wallet.conditions import ConditionValidTimes from chia.wallet.notification_store import Notification from chia.wallet.transaction_record import TransactionRecord -from chia.wallet.wallet_request_types import GetNotifications, GetNotificationsResponse +from chia.wallet.wallet_request_types import DeleteNotifications, GetNotifications, GetNotificationsResponse test_condition_valid_times: ConditionValidTimes = ConditionValidTimes(min_time=uint64(100), max_time=uint64(150)) @@ -111,15 +111,31 @@ def test_notifications_delete(capsys: object, get_test_cli_clients: tuple[TestRp # set RPC Client class NotificationsDeleteRpcClient(TestWalletRpcClient): - async def delete_notifications(self, ids: Optional[list[bytes32]] = None) -> bool: - self.add_to_log("delete_notifications", (ids,)) - return True + async def delete_notifications(self, request: DeleteNotifications) -> None: + self.add_to_log("delete_notifications", (request.ids,)) inst_rpc_client = NotificationsDeleteRpcClient() test_rpc_clients.wallet_rpc_client = inst_rpc_client + # Try all first command_args = ["wallet", "notifications", "delete", FINGERPRINT_ARG, "--all"] # these are various things that should be in the output - assert_list = ["Success: True"] + assert_list = ["Success!"] run_cli_command_and_assert(capsys, root_dir, command_args, assert_list) expected_calls: logType = {"delete_notifications": [(None,)]} test_rpc_clients.wallet_rpc_client.check_log(expected_calls) + # Next try specifying IDs + command_args = [ + "wallet", + "notifications", + "delete", + FINGERPRINT_ARG, + "--id", + bytes32.zeros.hex(), + "--id", + bytes32.zeros.hex(), + ] + # these are various things that should be in the output + assert_list = ["Success!"] + run_cli_command_and_assert(capsys, root_dir, command_args, assert_list) + expected_calls = {"delete_notifications": [([bytes32.zeros, bytes32.zeros],)]} + test_rpc_clients.wallet_rpc_client.check_log(expected_calls) diff --git a/chia/_tests/cmds/wallet/test_wallet.py b/chia/_tests/cmds/wallet/test_wallet.py index f0a384196c8d..86fdd159aabf 100644 --- a/chia/_tests/cmds/wallet/test_wallet.py +++ b/chia/_tests/cmds/wallet/test_wallet.py @@ -45,8 +45,14 @@ CancelOfferResponse, CATSpendResponse, CreateOfferForIDsResponse, + DeleteUnconfirmedTransactions, + ExtendDerivationIndex, + ExtendDerivationIndexResponse, FungibleAsset, + GetCurrentDerivationIndexResponse, GetHeightInfoResponse, + GetNextAddress, + GetNextAddressResponse, GetTransaction, GetTransactions, GetTransactionsResponse, @@ -493,11 +499,11 @@ def test_get_address(capsys: object, get_test_cli_clients: tuple[TestRpcClients, # set RPC Client class GetAddressWalletRpcClient(TestWalletRpcClient): - async def get_next_address(self, wallet_id: int, new_address: bool) -> str: - self.add_to_log("get_next_address", (wallet_id, new_address)) - if new_address: - return encode_puzzle_hash(get_bytes32(3), "xch") - return encode_puzzle_hash(get_bytes32(4), "xch") + async def get_next_address(self, request: GetNextAddress) -> GetNextAddressResponse: + self.add_to_log("get_next_address", (request.wallet_id, request.new_address)) + if request.new_address: + return GetNextAddressResponse(request.wallet_id, encode_puzzle_hash(get_bytes32(3), "xch")) + return GetNextAddressResponse(request.wallet_id, encode_puzzle_hash(get_bytes32(4), "xch")) inst_rpc_client = GetAddressWalletRpcClient() test_rpc_clients.wallet_rpc_client = inst_rpc_client @@ -569,8 +575,8 @@ def test_del_unconfirmed_tx(capsys: object, get_test_cli_clients: tuple[TestRpcC # set RPC Client class UnconfirmedTxRpcClient(TestWalletRpcClient): - async def delete_unconfirmed_transactions(self, wallet_id: int) -> None: - self.add_to_log("delete_unconfirmed_transactions", (wallet_id,)) + async def delete_unconfirmed_transactions(self, request: DeleteUnconfirmedTransactions) -> None: + self.add_to_log("delete_unconfirmed_transactions", (request.wallet_id,)) inst_rpc_client = UnconfirmedTxRpcClient() test_rpc_clients.wallet_rpc_client = inst_rpc_client @@ -594,9 +600,9 @@ def test_get_derivation_index(capsys: object, get_test_cli_clients: tuple[TestRp # set RPC Client class GetDerivationIndexRpcClient(TestWalletRpcClient): - async def get_current_derivation_index(self) -> str: + async def get_current_derivation_index(self) -> GetCurrentDerivationIndexResponse: self.add_to_log("get_current_derivation_index", ()) - return str(520) + return GetCurrentDerivationIndexResponse(uint32(520)) inst_rpc_client = GetDerivationIndexRpcClient() test_rpc_clients.wallet_rpc_client = inst_rpc_client @@ -625,8 +631,9 @@ def test_sign_message(capsys: object, get_test_cli_clients: tuple[TestRpcClients # these are various things that should be in the output assert_list = [ f"Message: {message.hex()}", - f"Public Key: {bytes([3] * 48).hex()}", - f"Signature: {bytes([6] * 576).hex()}", + "Public Key: b5acf3599bc5fa5da1c00f6cc3d5bcf1560def67778b7f50a8c373a83f78761505b6250ab776e38a292e26628009aec4", + "Signature: c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", f"Signing Mode: {SigningMode.CHIP_0002.value}", ] run_cli_command_and_assert(capsys, root_dir, [*command_args, f"-a{xch_addr}"], assert_list) @@ -641,9 +648,9 @@ def test_update_derivation_index(capsys: object, get_test_cli_clients: tuple[Tes # set RPC Client class UpdateDerivationIndexRpcClient(TestWalletRpcClient): - async def extend_derivation_index(self, index: int) -> str: - self.add_to_log("extend_derivation_index", (index,)) - return str(index) + async def extend_derivation_index(self, request: ExtendDerivationIndex) -> ExtendDerivationIndexResponse: + self.add_to_log("extend_derivation_index", (request.index,)) + return ExtendDerivationIndexResponse(request.index) inst_rpc_client = UpdateDerivationIndexRpcClient() test_rpc_clients.wallet_rpc_client = inst_rpc_client diff --git a/chia/_tests/pools/test_pool_rpc.py b/chia/_tests/pools/test_pool_rpc.py index 08883d41b947..ee99e981009f 100644 --- a/chia/_tests/pools/test_pool_rpc.py +++ b/chia/_tests/pools/test_pool_rpc.py @@ -41,6 +41,7 @@ from chia.wallet.util.wallet_types import WalletType from chia.wallet.wallet_node import WalletNode from chia.wallet.wallet_request_types import ( + DeleteUnconfirmedTransactions, GetTransactions, GetWalletBalance, GetWallets, @@ -463,7 +464,7 @@ async def pw_created(check_wallet_id: int) -> bool: def mempool_empty() -> bool: return full_node_api.full_node.mempool_manager.mempool.size() == 0 - await client.delete_unconfirmed_transactions(1) + await client.delete_unconfirmed_transactions(DeleteUnconfirmedTransactions(uint32(1))) await full_node_api.process_all_wallet_transactions(wallet=wallet) await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20) diff --git a/chia/_tests/wallet/did_wallet/test_did.py b/chia/_tests/wallet/did_wallet/test_did.py index ccb12e81f786..45fc25b39cb8 100644 --- a/chia/_tests/wallet/did_wallet/test_did.py +++ b/chia/_tests/wallet/did_wallet/test_did.py @@ -23,6 +23,7 @@ from chia.types.peer_info import PeerInfo from chia.types.signing_mode import CHIP_0002_SIGN_MESSAGE_PREFIX from chia.util.bech32m import decode_puzzle_hash, encode_puzzle_hash +from chia.util.byte_types import hexstr_to_bytes from chia.wallet.did_wallet.did_wallet import DIDWallet from chia.wallet.singleton import ( create_singleton_puzzle, @@ -1091,9 +1092,9 @@ async def test_did_sign_message(wallet_environments: WalletTestFramework): ) puzzle: Program = Program.to((CHIP_0002_SIGN_MESSAGE_PREFIX, message)) assert AugSchemeMPL.verify( - G1Element.from_bytes(bytes.fromhex(response["pubkey"])), + G1Element.from_bytes(hexstr_to_bytes(response["pubkey"])), puzzle.get_tree_hash(), - G2Element.from_bytes(bytes.fromhex(response["signature"])), + G2Element.from_bytes(hexstr_to_bytes(response["signature"])), ) # Test hex string message = "0123456789ABCDEF" @@ -1107,9 +1108,9 @@ async def test_did_sign_message(wallet_environments: WalletTestFramework): puzzle = Program.to((CHIP_0002_SIGN_MESSAGE_PREFIX, bytes.fromhex(message))) assert AugSchemeMPL.verify( - G1Element.from_bytes(bytes.fromhex(response["pubkey"])), + G1Element.from_bytes(hexstr_to_bytes(response["pubkey"])), puzzle.get_tree_hash(), - G2Element.from_bytes(bytes.fromhex(response["signature"])), + G2Element.from_bytes(hexstr_to_bytes(response["signature"])), ) # Test BLS sign string @@ -1119,15 +1120,15 @@ async def test_did_sign_message(wallet_environments: WalletTestFramework): { "id": encode_puzzle_hash(did_wallet_1.did_info.origin_coin.name(), AddressType.DID.value), "message": message, - "is_hex": "False", - "safe_mode": "False", + "is_hex": False, + "safe_mode": False, } ) assert AugSchemeMPL.verify( - G1Element.from_bytes(bytes.fromhex(response["pubkey"])), + G1Element.from_bytes(hexstr_to_bytes(response["pubkey"])), bytes(message, "utf-8"), - G2Element.from_bytes(bytes.fromhex(response["signature"])), + G2Element.from_bytes(hexstr_to_bytes(response["signature"])), ) # Test BLS sign hex message = "0123456789ABCDEF" @@ -1142,9 +1143,9 @@ async def test_did_sign_message(wallet_environments: WalletTestFramework): ) assert AugSchemeMPL.verify( - G1Element.from_bytes(bytes.fromhex(response["pubkey"])), - bytes.fromhex(message), - G2Element.from_bytes(bytes.fromhex(response["signature"])), + G1Element.from_bytes(hexstr_to_bytes(response["pubkey"])), + hexstr_to_bytes(message), + G2Element.from_bytes(hexstr_to_bytes(response["signature"])), ) diff --git a/chia/_tests/wallet/nft_wallet/test_nft_wallet.py b/chia/_tests/wallet/nft_wallet/test_nft_wallet.py index 0e81cb9f6e41..6bdb43e48e48 100644 --- a/chia/_tests/wallet/nft_wallet/test_nft_wallet.py +++ b/chia/_tests/wallet/nft_wallet/test_nft_wallet.py @@ -5,7 +5,7 @@ from typing import Any, Callable import pytest -from chia_rs import AugSchemeMPL, G1Element, G2Element +from chia_rs import AugSchemeMPL from chia_rs.sized_bytes import bytes32 from chia_rs.sized_ints import uint16, uint32, uint64 from clvm_tools.binutils import disassemble @@ -41,6 +41,7 @@ NFTTransferBulk, NFTTransferNFT, NFTWalletWithDID, + SignMessageByID, ) from chia.wallet.wallet_rpc_api import MAX_NFT_CHUNK_SIZE from chia.wallet.wallet_state_manager import WalletStateManager @@ -2757,51 +2758,55 @@ async def test_nft_sign_message(wallet_environments: WalletTestFramework) -> Non assert not coin.pending_transaction # Test general string message = "Hello World" - pubkey, sig, _ = await env.rpc_client.sign_message_by_id( - id=encode_puzzle_hash(coin.launcher_id, AddressType.NFT.value), message=message + sign_by_id_res = await env.rpc_client.sign_message_by_id( + SignMessageByID(id=encode_puzzle_hash(coin.launcher_id, AddressType.NFT.value), message=message) ) puzzle = Program.to((CHIP_0002_SIGN_MESSAGE_PREFIX, message)) assert AugSchemeMPL.verify( - G1Element.from_bytes(bytes.fromhex(pubkey)), + sign_by_id_res.pubkey, puzzle.get_tree_hash(), - G2Element.from_bytes(bytes.fromhex(sig)), + sign_by_id_res.signature, ) # Test hex string message = "0123456789ABCDEF" - pubkey, sig, _ = await env.rpc_client.sign_message_by_id( - id=encode_puzzle_hash(coin.launcher_id, AddressType.NFT.value), message=message, is_hex=True + sign_by_id_res = await env.rpc_client.sign_message_by_id( + SignMessageByID(id=encode_puzzle_hash(coin.launcher_id, AddressType.NFT.value), message=message, is_hex=True) ) puzzle = Program.to((CHIP_0002_SIGN_MESSAGE_PREFIX, bytes.fromhex(message))) assert AugSchemeMPL.verify( - G1Element.from_bytes(bytes.fromhex(pubkey)), + sign_by_id_res.pubkey, puzzle.get_tree_hash(), - G2Element.from_bytes(bytes.fromhex(sig)), + sign_by_id_res.signature, ) # Test BLS sign string message = "Hello World" - pubkey, sig, _ = await env.rpc_client.sign_message_by_id( - id=encode_puzzle_hash(coin.launcher_id, AddressType.NFT.value), - message=message, - is_hex=False, - safe_mode=False, + sign_by_id_res = await env.rpc_client.sign_message_by_id( + SignMessageByID( + id=encode_puzzle_hash(coin.launcher_id, AddressType.NFT.value), + message=message, + is_hex=False, + safe_mode=False, + ) ) assert AugSchemeMPL.verify( - G1Element.from_bytes(bytes.fromhex(pubkey)), + sign_by_id_res.pubkey, bytes(message, "utf-8"), - G2Element.from_bytes(bytes.fromhex(sig)), + sign_by_id_res.signature, ) # Test BLS sign hex message = "0123456789ABCDEF" - pubkey, sig, _ = await env.rpc_client.sign_message_by_id( - id=encode_puzzle_hash(coin.launcher_id, AddressType.NFT.value), - message=message, - is_hex=True, - safe_mode=False, + sign_by_id_res = await env.rpc_client.sign_message_by_id( + SignMessageByID( + id=encode_puzzle_hash(coin.launcher_id, AddressType.NFT.value), + message=message, + is_hex=True, + safe_mode=False, + ) ) assert AugSchemeMPL.verify( - G1Element.from_bytes(bytes.fromhex(pubkey)), + sign_by_id_res.pubkey, bytes.fromhex(message), - G2Element.from_bytes(bytes.fromhex(sig)), + sign_by_id_res.signature, ) diff --git a/chia/_tests/wallet/rpc/test_wallet_rpc.py b/chia/_tests/wallet/rpc/test_wallet_rpc.py index d714fc47a03b..d3936307cb22 100644 --- a/chia/_tests/wallet/rpc/test_wallet_rpc.py +++ b/chia/_tests/wallet/rpc/test_wallet_rpc.py @@ -108,6 +108,8 @@ CombineCoins, DefaultCAT, DeleteKey, + DeleteNotifications, + DeleteUnconfirmedTransactions, DIDCreateBackupFile, DIDGetDID, DIDGetMetadata, @@ -118,11 +120,13 @@ DIDTransferDID, DIDUpdateMetadata, FungibleAsset, + GetNextAddress, GetNotifications, GetPrivateKey, GetSyncStatusResponse, GetTimestampForHeight, GetTransaction, + GetTransactionCount, GetTransactions, GetWalletBalance, GetWalletBalances, @@ -194,11 +198,11 @@ async def farm_transaction( async def generate_funds(full_node_api: FullNodeSimulator, wallet_bundle: WalletBundle, num_blocks: int = 1) -> int: - wallet_id = 1 - initial_balances = ( - await wallet_bundle.rpc_client.get_wallet_balance(GetWalletBalance(uint32(wallet_id))) - ).wallet_balance - ph: bytes32 = decode_puzzle_hash(await wallet_bundle.rpc_client.get_next_address(wallet_id, True)) + wallet_id = uint32(1) + initial_balances = (await wallet_bundle.rpc_client.get_wallet_balance(GetWalletBalance(wallet_id))).wallet_balance + ph: bytes32 = decode_puzzle_hash( + (await wallet_bundle.rpc_client.get_next_address(GetNextAddress(wallet_id, True))).address + ) generated_funds = 0 for _ in range(num_blocks): await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) @@ -1096,14 +1100,16 @@ async def test_get_transaction_count(wallet_rpc_environment: WalletRpcTestEnviro all_transactions = (await client.get_transactions(GetTransactions(uint32(1)))).transactions assert len(all_transactions) > 0 - transaction_count = await client.get_transaction_count(1) - assert transaction_count == len(all_transactions) - transaction_count = await client.get_transaction_count(1, confirmed=False) - assert transaction_count == 0 - transaction_count = await client.get_transaction_count( - 1, type_filter=TransactionTypeFilter.include([TransactionType.INCOMING_CLAWBACK_SEND]) + transaction_count_response = await client.get_transaction_count(GetTransactionCount(uint32(1))) + assert transaction_count_response.count == len(all_transactions) + transaction_count_response = await client.get_transaction_count(GetTransactionCount(uint32(1), confirmed=False)) + assert transaction_count_response.count == 0 + transaction_count_response = await client.get_transaction_count( + GetTransactionCount( + uint32(1), type_filter=TransactionTypeFilter.include([TransactionType.INCOMING_CLAWBACK_SEND]) + ) ) - assert transaction_count == 0 + assert transaction_count_response.count == 0 @pytest.mark.parametrize( @@ -1210,8 +1216,8 @@ async def test_cat_endpoints(wallet_environments: WalletTestFramework, wallet_ty ] ) - addr_0 = await env_0.rpc_client.get_next_address(cat_0_id, False) - addr_1 = await env_1.rpc_client.get_next_address(cat_1_id, False) + addr_0 = (await env_0.rpc_client.get_next_address(GetNextAddress(cat_0_id, False))).address + addr_1 = (await env_1.rpc_client.get_next_address(GetNextAddress(cat_1_id, False))).address assert addr_0 != addr_1 @@ -1417,7 +1423,7 @@ async def test_offer_endpoints(wallet_environments: WalletTestFramework, wallet_ # Creates a wallet for the same CAT on wallet_2 and send 4 CAT from wallet_1 to it await env_2.rpc_client.create_wallet_for_existing_cat(cat_asset_id) - wallet_2_address = await env_2.rpc_client.get_next_address(cat_wallet_id, False) + wallet_2_address = (await env_2.rpc_client.get_next_address(GetNextAddress(cat_wallet_id, False))).address adds = [{"puzzle_hash": decode_puzzle_hash(wallet_2_address), "amount": uint64(4), "memos": ["the cat memo"]}] tx_res = ( await env_1.rpc_client.send_transaction_multi( @@ -2146,7 +2152,7 @@ async def test_key_and_address_endpoints(wallet_rpc_environment: WalletRpcTestEn wallet_node: WalletNode = env.wallet_1.node client: WalletRpcClient = env.wallet_1.rpc_client - address = await client.get_next_address(1, True) + address = (await client.get_next_address(GetNextAddress(uint32(1), True))).address assert len(address) > 10 pks = (await client.get_public_keys()).pk_fingerprints @@ -2165,7 +2171,7 @@ async def test_key_and_address_endpoints(wallet_rpc_environment: WalletRpcTestEn await time_out_assert(20, tx_in_mempool, True, client, created_tx.name) assert len(await wallet.wallet_state_manager.tx_store.get_unconfirmed_for_wallet(1)) == 1 - await client.delete_unconfirmed_transactions(1) + await client.delete_unconfirmed_transactions(DeleteUnconfirmedTransactions(uint32(1))) assert len(await wallet.wallet_state_manager.tx_store.get_unconfirmed_for_wallet(1)) == 0 sk_resp = await client.get_private_key(GetPrivateKey(pks[0])) @@ -2554,7 +2560,7 @@ async def test_notification_rpcs(wallet_rpc_environment: WalletRpcTestEnvironmen assert [notification] == (await client_2.get_notifications(GetNotifications(None, None, uint32(1)))).notifications assert [] == (await client_2.get_notifications(GetNotifications(None, uint32(1), None))).notifications assert [notification] == (await client_2.get_notifications(GetNotifications(None, None, None))).notifications - assert await client_2.delete_notifications() + await client_2.delete_notifications(DeleteNotifications()) assert [] == (await client_2.get_notifications(GetNotifications([notification.id]))).notifications async with wallet_2.wallet_state_manager.new_action_scope(DEFAULT_TX_CONFIG, push=True) as action_scope: @@ -2576,7 +2582,7 @@ async def test_notification_rpcs(wallet_rpc_environment: WalletRpcTestEnvironmen await time_out_assert(20, env.wallet_2.wallet.get_confirmed_balance, uint64(200000000000)) notification = (await client_2.get_notifications(GetNotifications())).notifications[0] - assert await client_2.delete_notifications([notification.id]) + await client_2.delete_notifications(DeleteNotifications([notification.id])) assert [] == (await client_2.get_notifications(GetNotifications([notification.id]))).notifications @@ -2790,7 +2796,7 @@ async def test_set_wallet_resync_on_startup(wallet_rpc_environment: WalletRpcTes nft_wallet = await wc.create_new_nft_wallet(None) nft_wallet_id = nft_wallet["wallet_id"] - address = await wc.get_next_address(env.wallet_1.wallet.id(), True) + address = (await wc.get_next_address(GetNextAddress(env.wallet_1.wallet.id(), True))).address await wc.mint_nft( request=NFTMintNFTRequest( wallet_id=nft_wallet_id, diff --git a/chia/_tests/wallet/test_wallet.py b/chia/_tests/wallet/test_wallet.py index a24258f22559..1f1662953276 100644 --- a/chia/_tests/wallet/test_wallet.py +++ b/chia/_tests/wallet/test_wallet.py @@ -19,6 +19,7 @@ from chia.types.peer_info import PeerInfo from chia.types.signing_mode import CHIP_0002_SIGN_MESSAGE_PREFIX from chia.util.bech32m import encode_puzzle_hash +from chia.util.byte_types import hexstr_to_bytes from chia.util.errors import Err from chia.wallet.conditions import ConditionValidTimes from chia.wallet.derive_keys import master_sk_to_wallet_sk @@ -2008,9 +2009,9 @@ async def test_sign_message(self, wallet_environments: WalletTestFramework) -> N puzzle: Program = Program.to((CHIP_0002_SIGN_MESSAGE_PREFIX, message)) assert AugSchemeMPL.verify( - G1Element.from_bytes(bytes.fromhex(response["pubkey"])), + G1Element.from_bytes(hexstr_to_bytes(response["pubkey"])), puzzle.get_tree_hash(), - G2Element.from_bytes(bytes.fromhex(response["signature"])), + G2Element.from_bytes(hexstr_to_bytes(response["signature"])), ) # Test hex string message = "0123456789ABCDEF" @@ -2020,9 +2021,9 @@ async def test_sign_message(self, wallet_environments: WalletTestFramework) -> N puzzle = Program.to((CHIP_0002_SIGN_MESSAGE_PREFIX, bytes.fromhex(message))) assert AugSchemeMPL.verify( - G1Element.from_bytes(bytes.fromhex(response["pubkey"])), + G1Element.from_bytes(hexstr_to_bytes(response["pubkey"])), puzzle.get_tree_hash(), - G2Element.from_bytes(bytes.fromhex(response["signature"])), + G2Element.from_bytes(hexstr_to_bytes(response["signature"])), ) # Test informal input message = "0123456789ABCDEF" @@ -2032,9 +2033,9 @@ async def test_sign_message(self, wallet_environments: WalletTestFramework) -> N puzzle = Program.to((CHIP_0002_SIGN_MESSAGE_PREFIX, bytes.fromhex(message))) assert AugSchemeMPL.verify( - G1Element.from_bytes(bytes.fromhex(response["pubkey"])), + G1Element.from_bytes(hexstr_to_bytes(response["pubkey"])), puzzle.get_tree_hash(), - G2Element.from_bytes(bytes.fromhex(response["signature"])), + G2Element.from_bytes(hexstr_to_bytes(response["signature"])), ) # Test BLS sign string message = "Hello World" @@ -2043,9 +2044,9 @@ async def test_sign_message(self, wallet_environments: WalletTestFramework) -> N ) assert AugSchemeMPL.verify( - G1Element.from_bytes(bytes.fromhex(response["pubkey"])), + G1Element.from_bytes(hexstr_to_bytes(response["pubkey"])), bytes(message, "utf-8"), - G2Element.from_bytes(bytes.fromhex(response["signature"])), + G2Element.from_bytes(hexstr_to_bytes(response["signature"])), ) # Test BLS sign hex message = "0123456789ABCDEF" @@ -2054,9 +2055,9 @@ async def test_sign_message(self, wallet_environments: WalletTestFramework) -> N ) assert AugSchemeMPL.verify( - G1Element.from_bytes(bytes.fromhex(response["pubkey"])), - bytes.fromhex(message), - G2Element.from_bytes(bytes.fromhex(response["signature"])), + G1Element.from_bytes(hexstr_to_bytes(response["pubkey"])), + hexstr_to_bytes(message), + G2Element.from_bytes(hexstr_to_bytes(response["signature"])), ) @pytest.mark.parametrize( diff --git a/chia/_tests/wallet/test_wallet_state_manager.py b/chia/_tests/wallet/test_wallet_state_manager.py index 0429e6090dba..1289f37cca53 100644 --- a/chia/_tests/wallet/test_wallet_state_manager.py +++ b/chia/_tests/wallet/test_wallet_state_manager.py @@ -21,7 +21,7 @@ from chia.wallet.transaction_record import TransactionRecord from chia.wallet.util.transaction_type import TransactionType from chia.wallet.util.wallet_types import WalletType -from chia.wallet.wallet_request_types import PushTransactions +from chia.wallet.wallet_request_types import ExtendDerivationIndex, PushTransactions from chia.wallet.wallet_rpc_api import MAX_DERIVATION_INDEX_DELTA from chia.wallet.wallet_spend_bundle import WalletSpendBundle from chia.wallet.wallet_state_manager import WalletStateManager @@ -286,6 +286,11 @@ async def get_puzzle_hash_state() -> PuzzleHashState: expected_state = await get_puzzle_hash_state() + # Quick test of this RPC + assert ( + await wallet_environments.environments[0].rpc_client.get_current_derivation_index() + ).index == expected_state.highest_index + # `create_more_puzzle_hashes` # No-op result = await wsm.create_more_puzzle_hashes() @@ -420,7 +425,7 @@ async def get_puzzle_hash_state() -> PuzzleHashState: (0,), ) with pytest.raises(ValueError): - await rpc_client.extend_derivation_index(0) + await rpc_client.extend_derivation_index(ExtendDerivationIndex(uint32(0))) # Reset to a normal state await wsm.puzzle_store.delete_wallet(wsm.main_wallet.id()) @@ -431,15 +436,17 @@ async def get_puzzle_hash_state() -> PuzzleHashState: # Test an index already created with pytest.raises(ValueError): - await rpc_client.extend_derivation_index(0) + await rpc_client.extend_derivation_index(ExtendDerivationIndex(uint32(0))) # Test an index too far in the future with pytest.raises(ValueError): - await rpc_client.extend_derivation_index(MAX_DERIVATION_INDEX_DELTA + expected_state.highest_index + 1) + await rpc_client.extend_derivation_index( + ExtendDerivationIndex(uint32(MAX_DERIVATION_INDEX_DELTA + expected_state.highest_index + 1)) + ) # Test the actual functionality - assert await rpc_client.extend_derivation_index(expected_state.highest_index + 5) == str( - expected_state.highest_index + 5 - ) + assert ( + await rpc_client.extend_derivation_index(ExtendDerivationIndex(uint32(expected_state.highest_index + 5))) + ).index == expected_state.highest_index + 5 expected_state = PuzzleHashState(expected_state.highest_index + 5, expected_state.used_up_to_index) assert await get_puzzle_hash_state() == expected_state diff --git a/chia/cmds/wallet_funcs.py b/chia/cmds/wallet_funcs.py index a4cf58ea5afa..50c4f17c0ef3 100644 --- a/chia/cmds/wallet_funcs.py +++ b/chia/cmds/wallet_funcs.py @@ -45,6 +45,8 @@ from chia.wallet.wallet_coin_store import GetCoinRecords from chia.wallet.wallet_request_types import ( CATSpendResponse, + DeleteNotifications, + DeleteUnconfirmedTransactions, DIDFindLostDID, DIDGetDID, DIDGetInfo, @@ -52,7 +54,9 @@ DIDSetWalletName, DIDTransferDID, DIDUpdateMetadata, + ExtendDerivationIndex, FungibleAsset, + GetNextAddress, GetNotifications, GetTransaction, GetTransactions, @@ -69,6 +73,10 @@ NFTTransferNFT, RoyaltyAsset, SendTransactionResponse, + SignMessageByAddress, + SignMessageByAddressResponse, + SignMessageByID, + SignMessageByIDResponse, VCAddProofs, VCGet, VCGetList, @@ -418,7 +426,7 @@ async def get_address( root_path: pathlib.Path, wallet_rpc_port: Optional[int], fp: Optional[int], wallet_id: int, new_address: bool ) -> None: async with get_wallet_client(root_path, wallet_rpc_port, fp) as (wallet_client, _, _): - res = await wallet_client.get_next_address(wallet_id, new_address) + res = (await wallet_client.get_next_address(GetNextAddress(uint32(wallet_id), new_address))).address print(res) @@ -426,14 +434,14 @@ async def delete_unconfirmed_transactions( root_path: pathlib.Path, wallet_rpc_port: Optional[int], fp: Optional[int], wallet_id: int ) -> None: async with get_wallet_client(root_path, wallet_rpc_port, fp) as (wallet_client, fingerprint, _): - await wallet_client.delete_unconfirmed_transactions(wallet_id) + await wallet_client.delete_unconfirmed_transactions(DeleteUnconfirmedTransactions(uint32(wallet_id))) print(f"Successfully deleted all unconfirmed transactions for wallet id {wallet_id} on key {fingerprint}") async def get_derivation_index(root_path: pathlib.Path, wallet_rpc_port: Optional[int], fp: Optional[int]) -> None: async with get_wallet_client(root_path, wallet_rpc_port, fp) as (wallet_client, _, _): res = await wallet_client.get_current_derivation_index() - print(f"Last derivation index: {res}") + print(f"Last derivation index: {res.index}") async def update_derivation_index( @@ -441,8 +449,8 @@ async def update_derivation_index( ) -> None: async with get_wallet_client(root_path, wallet_rpc_port, fp) as (wallet_client, _, _): print("Updating derivation index... This may take a while.") - res = await wallet_client.extend_derivation_index(index) - print(f"Updated derivation index: {res}") + res = await wallet_client.extend_derivation_index(ExtendDerivationIndex(uint32(index))) + print(f"Updated derivation index: {res.index}") print("Your balances may take a while to update.") @@ -1599,9 +1607,11 @@ async def delete_notifications( ) -> None: async with get_wallet_client(root_path, wallet_rpc_port, fp) as (wallet_client, _, _): if delete_all: - print(f"Success: {await wallet_client.delete_notifications()}") + await wallet_client.delete_notifications(DeleteNotifications()) + print("Success!") else: - print(f"Success: {await wallet_client.delete_notifications(ids=list(ids))}") + await wallet_client.delete_notifications(DeleteNotifications(ids=list(ids))) + print("Success!") async def sign_message( @@ -1616,31 +1626,32 @@ async def sign_message( nft_id: Optional[CliAddress] = None, ) -> None: async with get_wallet_client(root_path, wallet_rpc_port, fp) as (wallet_client, _, _): + response: Union[SignMessageByAddressResponse, SignMessageByIDResponse] if addr_type == AddressType.XCH: if address is None: print("Address is required for XCH address type.") return - pubkey, signature, signing_mode = await wallet_client.sign_message_by_address( - address.original_address, message + response = await wallet_client.sign_message_by_address( + SignMessageByAddress(address.original_address, message) ) elif addr_type == AddressType.DID: if did_id is None: print("DID id is required for DID address type.") return - pubkey, signature, signing_mode = await wallet_client.sign_message_by_id(did_id.original_address, message) + response = await wallet_client.sign_message_by_id(SignMessageByID(did_id.original_address, message)) elif addr_type == AddressType.NFT: if nft_id is None: print("NFT id is required for NFT address type.") return - pubkey, signature, signing_mode = await wallet_client.sign_message_by_id(nft_id.original_address, message) + response = await wallet_client.sign_message_by_id(SignMessageByID(nft_id.original_address, message)) else: print("Invalid wallet type.") return print("") print(f"Message: {message}") - print(f"Public Key: {pubkey}") - print(f"Signature: {signature}") - print(f"Signing Mode: {signing_mode}") + print(f"Public Key: {response.pubkey!s}") + print(f"Signature: {response.signature!s}") + print(f"Signing Mode: {response.signing_mode}") async def spend_clawback( diff --git a/chia/wallet/wallet_request_types.py b/chia/wallet/wallet_request_types.py index 10c56d1b78b8..74b616a34bff 100644 --- a/chia/wallet/wallet_request_types.py +++ b/chia/wallet/wallet_request_types.py @@ -353,6 +353,12 @@ class GetNotificationsResponse(Streamable): notifications: list[Notification] +@streamable +@dataclass(frozen=True) +class DeleteNotifications(Streamable): + ids: Optional[list[bytes32]] = None + + @streamable @dataclass(frozen=True) class VerifySignature(Streamable): @@ -370,6 +376,41 @@ class VerifySignatureResponse(Streamable): error: Optional[str] = None +@streamable +@dataclass(frozen=True) +class SignMessageByAddress(Streamable): + address: str + message: str + is_hex: bool = False + safe_mode: bool = True + + +@streamable +@dataclass(frozen=True) +class SignMessageByAddressResponse(Streamable): + pubkey: G1Element + signature: G2Element + signing_mode: str + + +@streamable +@dataclass(frozen=True) +class SignMessageByID(Streamable): + id: str + message: str + is_hex: bool = False + safe_mode: bool = True + + +@streamable +@dataclass(frozen=True) +class SignMessageByIDResponse(Streamable): + pubkey: G1Element + signature: G2Element + latest_coin_id: bytes32 + signing_mode: str + + @streamable @dataclass(frozen=True) class GetTransactionMemo(Streamable): @@ -402,6 +443,60 @@ def from_json_dict(cls, json_dict: dict[str, Any]) -> GetTransactionMemoResponse ) +@streamable +@dataclass(frozen=True) +class GetTransactionCount(Streamable): + wallet_id: uint32 + confirmed: Optional[bool] = None + type_filter: Optional[TransactionTypeFilter] = None + + +@streamable +@dataclass(frozen=True) +class GetTransactionCountResponse(Streamable): + wallet_id: uint32 + count: uint16 + + +@streamable +@dataclass(frozen=True) +class GetNextAddress(Streamable): + wallet_id: uint32 + new_address: bool = False + save_derivations: bool = True + + +@streamable +@dataclass(frozen=True) +class GetNextAddressResponse(Streamable): + wallet_id: uint32 + address: str + + +@streamable +@dataclass(frozen=True) +class DeleteUnconfirmedTransactions(Streamable): + wallet_id: uint32 + + +@streamable +@dataclass(frozen=True) +class GetCurrentDerivationIndexResponse(Streamable): + index: Optional[uint32] + + +@streamable +@dataclass(frozen=True) +class ExtendDerivationIndex(Streamable): + index: uint32 + + +@streamable +@dataclass(frozen=True) +class ExtendDerivationIndexResponse(Streamable): + index: Optional[uint32] + + @streamable @dataclass(frozen=True) class GetOffersCountResponse(Streamable): diff --git a/chia/wallet/wallet_rpc_api.py b/chia/wallet/wallet_rpc_api.py index a7955f7a1f4a..3c687a8f6ea4 100644 --- a/chia/wallet/wallet_rpc_api.py +++ b/chia/wallet/wallet_rpc_api.py @@ -90,7 +90,7 @@ from chia.wallet.util.compute_hints import compute_spend_hints_and_additions from chia.wallet.util.compute_memos import compute_memos from chia.wallet.util.curry_and_treehash import NIL_TREEHASH -from chia.wallet.util.query_filter import FilterMode, HashFilter, TransactionTypeFilter +from chia.wallet.util.query_filter import FilterMode, HashFilter from chia.wallet.util.transaction_type import CLAWBACK_INCOMING_TRANSACTION_TYPES, TransactionType from chia.wallet.util.tx_config import DEFAULT_TX_CONFIG, TXConfig, TXConfigLoader from chia.wallet.util.wallet_sync_utils import fetch_coin_spend_for_coin_state @@ -119,6 +119,8 @@ CreateNewDL, CreateNewDLResponse, DeleteKey, + DeleteNotifications, + DeleteUnconfirmedTransactions, DIDCreateBackupFile, DIDCreateBackupFileResponse, DIDFindLostDID, @@ -165,11 +167,16 @@ Empty, ExecuteSigningInstructions, ExecuteSigningInstructionsResponse, + ExtendDerivationIndex, + ExtendDerivationIndexResponse, GatherSigningInfo, GatherSigningInfoResponse, GenerateMnemonicResponse, + GetCurrentDerivationIndexResponse, GetHeightInfoResponse, GetLoggedInFingerprintResponse, + GetNextAddress, + GetNextAddressResponse, GetNotifications, GetNotificationsResponse, GetPrivateKey, @@ -180,6 +187,8 @@ GetTimestampForHeight, GetTimestampForHeightResponse, GetTransaction, + GetTransactionCount, + GetTransactionCountResponse, GetTransactionMemo, GetTransactionMemoResponse, GetTransactionResponse, @@ -234,6 +243,10 @@ PWStatus, PWStatusResponse, SetWalletResyncOnStartup, + SignMessageByAddress, + SignMessageByAddressResponse, + SignMessageByID, + SignMessageByIDResponse, SplitCoins, SplitCoinsResponse, SubmitTransactions, @@ -255,6 +268,8 @@ VCRevokeResponse, VCSpend, VCSpendResponse, + VerifySignature, + VerifySignatureResponse, WalletInfoResponse, ) from chia.wallet.wallet_spend_bundle import WalletSpendBundle @@ -1538,46 +1553,39 @@ async def get_transactions(self, request: GetTransactions) -> GetTransactionsRes wallet_id=request.wallet_id, ) - async def get_transaction_count(self, request: dict[str, Any]) -> EndpointResult: - wallet_id = int(request["wallet_id"]) - type_filter = None - if "type_filter" in request: - type_filter = TransactionTypeFilter.from_json_dict(request["type_filter"]) + @marshal + async def get_transaction_count(self, request: GetTransactionCount) -> GetTransactionCountResponse: count = await self.service.wallet_state_manager.tx_store.get_transaction_count_for_wallet( - wallet_id, confirmed=request.get("confirmed", None), type_filter=type_filter + request.wallet_id, confirmed=request.confirmed, type_filter=request.type_filter + ) + return GetTransactionCountResponse( + request.wallet_id, + uint16(count), ) - return { - "count": count, - "wallet_id": wallet_id, - } - async def get_next_address(self, request: dict[str, Any]) -> EndpointResult: + @marshal + async def get_next_address(self, request: GetNextAddress) -> GetNextAddressResponse: """ Returns a new address """ - if request["new_address"] is True: - create_new = True - else: - create_new = False - wallet_id = uint32(request["wallet_id"]) - wallet = self.service.wallet_state_manager.wallets[wallet_id] + wallet = self.service.wallet_state_manager.wallets[request.wallet_id] selected = self.service.config["selected_network"] prefix = self.service.config["network_overrides"]["config"][selected]["address_prefix"] if wallet.type() in {WalletType.STANDARD_WALLET, WalletType.CAT, WalletType.CRCAT, WalletType.RCAT}: async with self.service.wallet_state_manager.new_action_scope( - DEFAULT_TX_CONFIG, push=request.get("save_derivations", True) + DEFAULT_TX_CONFIG, push=request.save_derivations ) as action_scope: raw_puzzle_hash = await action_scope.get_puzzle_hash( - self.service.wallet_state_manager, override_reuse_puzhash_with=not create_new + self.service.wallet_state_manager, override_reuse_puzhash_with=not request.new_address ) address = encode_puzzle_hash(raw_puzzle_hash, prefix) else: raise ValueError(f"Wallet type {wallet.type()} cannot create puzzle hashes") - return { - "wallet_id": wallet_id, - "address": address, - } + return GetNextAddressResponse( + request.wallet_id, + address, + ) @tx_endpoint(push=True) async def send_transaction( @@ -1717,20 +1725,20 @@ async def spend_clawback_coins( "transactions": None, # tx_endpoint wrapper will take care of this } - async def delete_unconfirmed_transactions(self, request: dict[str, Any]) -> EndpointResult: - wallet_id = uint32(request["wallet_id"]) - if wallet_id not in self.service.wallet_state_manager.wallets: - raise ValueError(f"Wallet id {wallet_id} does not exist") + @marshal + async def delete_unconfirmed_transactions(self, request: DeleteUnconfirmedTransactions) -> Empty: + if request.wallet_id not in self.service.wallet_state_manager.wallets: + raise ValueError(f"Wallet id {request.wallet_id} does not exist") if await self.service.wallet_state_manager.synced() is False: raise ValueError("Wallet needs to be fully synced.") async with self.service.wallet_state_manager.db_wrapper.writer(): - await self.service.wallet_state_manager.tx_store.delete_unconfirmed_transactions(wallet_id) - wallet = self.service.wallet_state_manager.wallets[wallet_id] + await self.service.wallet_state_manager.tx_store.delete_unconfirmed_transactions(request.wallet_id) + wallet = self.service.wallet_state_manager.wallets[request.wallet_id] if wallet.type() == WalletType.POOLING_WALLET.value: assert isinstance(wallet, PoolWallet) wallet.target_state = None - return {} + return Empty() async def select_coins( self, @@ -1872,26 +1880,23 @@ async def get_coin_records_by_names(self, request: dict[str, Any]) -> EndpointRe return {"coin_records": [cr.to_json_dict() for cr in coin_records]} - async def get_current_derivation_index(self, request: dict[str, Any]) -> dict[str, Any]: + @marshal + async def get_current_derivation_index(self, request: Empty) -> GetCurrentDerivationIndexResponse: assert self.service.wallet_state_manager is not None index: Optional[uint32] = await self.service.wallet_state_manager.puzzle_store.get_last_derivation_path() - return {"success": True, "index": index} + return GetCurrentDerivationIndexResponse(index) - async def extend_derivation_index(self, request: dict[str, Any]) -> dict[str, Any]: + @marshal + async def extend_derivation_index(self, request: ExtendDerivationIndex) -> ExtendDerivationIndexResponse: assert self.service.wallet_state_manager is not None - # Require a new max derivation index - if "index" not in request: - raise ValueError("Derivation index is required") - # Require that the wallet is fully synced synced = await self.service.wallet_state_manager.synced() if synced is False: raise ValueError("Wallet needs to be fully synced before extending derivation index") - index = uint32(request["index"]) current: Optional[uint32] = await self.service.wallet_state_manager.puzzle_store.get_last_derivation_path() # Additional sanity check that the wallet is synced @@ -1899,10 +1904,10 @@ async def extend_derivation_index(self, request: dict[str, Any]) -> dict[str, An raise ValueError("No current derivation record found, unable to extend index") # Require that the new index is greater than the current index - if index <= current: + if request.index <= current: raise ValueError(f"New derivation index must be greater than current index: {current}") - if index - current > MAX_DERIVATION_INDEX_DELTA: + if request.index - current > MAX_DERIVATION_INDEX_DELTA: raise ValueError( "Too many derivations requested. " f"Use a derivation index less than {current + MAX_DERIVATION_INDEX_DELTA + 1}" @@ -1912,14 +1917,13 @@ async def extend_derivation_index(self, request: dict[str, Any]) -> dict[str, An # to preserve the current last used index, so we call create_more_puzzle_hashes with # mark_existing_as_used=False result = await self.service.wallet_state_manager.create_more_puzzle_hashes( - from_zero=False, mark_existing_as_used=False, up_to_index=index, num_additional_phs=0 + from_zero=False, mark_existing_as_used=False, up_to_index=request.index, num_additional_phs=0 ) await result.commit(self.service.wallet_state_manager) - updated: Optional[uint32] = await self.service.wallet_state_manager.puzzle_store.get_last_derivation_path() - updated_index = updated if updated is not None else None + updated_index = await self.service.wallet_state_manager.puzzle_store.get_last_derivation_path() - return {"success": True, "index": updated_index} + return ExtendDerivationIndexResponse(updated_index) @marshal async def get_notifications(self, request: GetNotifications) -> GetNotificationsResponse: @@ -1938,16 +1942,16 @@ async def get_notifications(self, request: GetNotifications) -> GetNotifications return GetNotificationsResponse(notifications) - async def delete_notifications(self, request: dict[str, Any]) -> EndpointResult: - ids: Optional[list[str]] = request.get("ids", None) - if ids is None: + @marshal + async def delete_notifications(self, request: DeleteNotifications) -> Empty: + if request.ids is None: await self.service.wallet_state_manager.notification_manager.notification_store.delete_all_notifications() else: await self.service.wallet_state_manager.notification_manager.notification_store.delete_notifications( - [bytes32.from_hexstr(id) for id in ids] + request.ids ) - return {} + return Empty() @tx_endpoint(push=True) async def send_notification( @@ -1967,116 +1971,101 @@ async def send_notification( return {"tx": None, "transactions": None} # tx_endpoint wrapper will take care of this - async def verify_signature(self, request: dict[str, Any]) -> EndpointResult: + @marshal + async def verify_signature(self, request: VerifySignature) -> VerifySignatureResponse: """ Given a public key, message and signature, verify if it is valid. :param request: :return: """ - input_message: str = request["message"] - signing_mode_str: Optional[str] = request.get("signing_mode") # Default to BLS_MESSAGE_AUGMENTATION_HEX_INPUT as this RPC was originally designed to verify # signatures made by `chia keys sign`, which uses BLS_MESSAGE_AUGMENTATION_HEX_INPUT - if signing_mode_str is None: + if request.signing_mode is None: signing_mode = SigningMode.BLS_MESSAGE_AUGMENTATION_HEX_INPUT else: try: - signing_mode = SigningMode(signing_mode_str) + signing_mode = SigningMode(request.signing_mode) except ValueError: - raise ValueError(f"Invalid signing mode: {signing_mode_str!r}") + raise ValueError(f"Invalid signing mode: {request.signing_mode!r}") if signing_mode in {SigningMode.CHIP_0002, SigningMode.CHIP_0002_P2_DELEGATED_CONDITIONS}: # CHIP-0002 message signatures are made over the tree hash of: # ("Chia Signed Message", message) - message_to_verify: bytes = Program.to((CHIP_0002_SIGN_MESSAGE_PREFIX, input_message)).get_tree_hash() + message_to_verify: bytes = Program.to((CHIP_0002_SIGN_MESSAGE_PREFIX, request.message)).get_tree_hash() elif signing_mode == SigningMode.BLS_MESSAGE_AUGMENTATION_HEX_INPUT: # Message is expected to be a hex string - message_to_verify = hexstr_to_bytes(input_message) + message_to_verify = hexstr_to_bytes(request.message) elif signing_mode == SigningMode.BLS_MESSAGE_AUGMENTATION_UTF8_INPUT: # Message is expected to be a UTF-8 string - message_to_verify = bytes(input_message, "utf-8") + message_to_verify = bytes(request.message, "utf-8") else: - raise ValueError(f"Unsupported signing mode: {signing_mode_str!r}") + raise ValueError(f"Unsupported signing mode: {request.signing_mode!r}") # Verify using the BLS message augmentation scheme is_valid = AugSchemeMPL.verify( - G1Element.from_bytes(hexstr_to_bytes(request["pubkey"])), + request.pubkey, message_to_verify, - G2Element.from_bytes(hexstr_to_bytes(request["signature"])), + request.signature, ) - address = request.get("address") - if address is not None: + if request.address is not None: # For signatures made by the sign_message_by_address/sign_message_by_id # endpoints, the "address" field should contain the p2_address of the NFT/DID # that was used to sign the message. - puzzle_hash: bytes32 = decode_puzzle_hash(address) + puzzle_hash: bytes32 = decode_puzzle_hash(request.address) expected_puzzle_hash: Optional[bytes32] = None if signing_mode == SigningMode.CHIP_0002_P2_DELEGATED_CONDITIONS: - puzzle = p2_delegated_conditions.puzzle_for_pk(Program.to(hexstr_to_bytes(request["pubkey"]))) + puzzle = p2_delegated_conditions.puzzle_for_pk(Program.to(request.pubkey)) expected_puzzle_hash = bytes32(puzzle.get_tree_hash()) else: - expected_puzzle_hash = puzzle_hash_for_synthetic_public_key( - G1Element.from_bytes(hexstr_to_bytes(request["pubkey"])) - ) + expected_puzzle_hash = puzzle_hash_for_synthetic_public_key(request.pubkey) if puzzle_hash != expected_puzzle_hash: - return {"isValid": False, "error": "Public key doesn't match the address"} + return VerifySignatureResponse(isValid=False, error="Public key doesn't match the address") if is_valid: - return {"isValid": is_valid} + return VerifySignatureResponse(isValid=is_valid) else: - return {"isValid": False, "error": "Signature is invalid."} + return VerifySignatureResponse(isValid=False, error="Signature is invalid.") - async def sign_message_by_address(self, request: dict[str, Any]) -> EndpointResult: + @marshal + async def sign_message_by_address(self, request: SignMessageByAddress) -> SignMessageByAddressResponse: """ Given a derived P2 address, sign the message by its private key. :param request: :return: """ - puzzle_hash: bytes32 = decode_puzzle_hash(request["address"]) - is_hex: bool = request.get("is_hex", False) - if isinstance(is_hex, str): - is_hex = True if is_hex.lower() == "true" else False - safe_mode: bool = request.get("safe_mode", True) - if isinstance(safe_mode, str): - safe_mode = True if safe_mode.lower() == "true" else False + puzzle_hash: bytes32 = decode_puzzle_hash(request.address) mode: SigningMode = SigningMode.CHIP_0002 - if is_hex and safe_mode: + if request.is_hex and request.safe_mode: mode = SigningMode.CHIP_0002_HEX_INPUT - elif not is_hex and not safe_mode: + elif not request.is_hex and not request.safe_mode: mode = SigningMode.BLS_MESSAGE_AUGMENTATION_UTF8_INPUT - elif is_hex and not safe_mode: + elif request.is_hex and not request.safe_mode: mode = SigningMode.BLS_MESSAGE_AUGMENTATION_HEX_INPUT pubkey, signature = await self.service.wallet_state_manager.main_wallet.sign_message( - request["message"], puzzle_hash, mode + request.message, puzzle_hash, mode + ) + return SignMessageByAddressResponse( + pubkey=pubkey, + signature=signature, + signing_mode=mode.value, ) - return { - "success": True, - "pubkey": str(pubkey), - "signature": str(signature), - "signing_mode": mode.value, - } - async def sign_message_by_id(self, request: dict[str, Any]) -> EndpointResult: + @marshal + async def sign_message_by_id(self, request: SignMessageByID) -> SignMessageByIDResponse: """ Given a NFT/DID ID, sign the message by the P2 private key. :param request: :return: """ - entity_id: bytes32 = decode_puzzle_hash(request["id"]) + entity_id: bytes32 = decode_puzzle_hash(request.id) selected_wallet: Optional[WalletProtocol[Any]] = None - is_hex: bool = request.get("is_hex", False) - if isinstance(is_hex, str): - is_hex = True if is_hex.lower() == "true" else False - safe_mode: bool = request.get("safe_mode", True) - if isinstance(safe_mode, str): - safe_mode = True if safe_mode.lower() == "true" else False mode: SigningMode = SigningMode.CHIP_0002 - if is_hex and safe_mode: + if request.is_hex and request.safe_mode: mode = SigningMode.CHIP_0002_HEX_INPUT - elif not is_hex and not safe_mode: + elif not request.is_hex and not request.safe_mode: mode = SigningMode.BLS_MESSAGE_AUGMENTATION_UTF8_INPUT - elif is_hex and not safe_mode: + elif request.is_hex and not request.safe_mode: mode = SigningMode.BLS_MESSAGE_AUGMENTATION_HEX_INPUT - if is_valid_address(request["id"], {AddressType.DID}, self.service.config): + if is_valid_address(request.id, {AddressType.DID}, self.service.config): for wallet in self.service.wallet_state_manager.wallets.values(): if wallet.type() == WalletType.DECENTRALIZED_ID.value: assert isinstance(wallet, DIDWallet) @@ -2085,11 +2074,11 @@ async def sign_message_by_id(self, request: dict[str, Any]) -> EndpointResult: selected_wallet = wallet break if selected_wallet is None: - return {"success": False, "error": f"DID for {entity_id.hex()} doesn't exist."} + raise ValueError(f"DID for {entity_id.hex()} doesn't exist.") assert isinstance(selected_wallet, DIDWallet) - pubkey, signature = await selected_wallet.sign_message(request["message"], mode) + pubkey, signature = await selected_wallet.sign_message(request.message, mode) latest_coin_id = (await selected_wallet.get_coin()).name() - elif is_valid_address(request["id"], {AddressType.NFT}, self.service.config): + elif is_valid_address(request.id, {AddressType.NFT}, self.service.config): target_nft: Optional[NFTCoinInfo] = None for wallet in self.service.wallet_state_manager.wallets.values(): if wallet.type() == WalletType.NFT.value: @@ -2100,21 +2089,20 @@ async def sign_message_by_id(self, request: dict[str, Any]) -> EndpointResult: target_nft = nft break if selected_wallet is None or target_nft is None: - return {"success": False, "error": f"NFT for {entity_id.hex()} doesn't exist."} + raise ValueError(f"NFT for {entity_id.hex()} doesn't exist.") assert isinstance(selected_wallet, NFTWallet) - pubkey, signature = await selected_wallet.sign_message(request["message"], target_nft, mode) + pubkey, signature = await selected_wallet.sign_message(request.message, target_nft, mode) latest_coin_id = target_nft.coin.name() else: - return {"success": False, "error": f"Unknown ID type, {request['id']}"} + raise ValueError(f"Unknown ID type, {request.id}") - return { - "success": True, - "pubkey": str(pubkey), - "signature": str(signature), - "latest_coin_id": latest_coin_id.hex() if latest_coin_id is not None else None, - "signing_mode": mode.value, - } + return SignMessageByIDResponse( + pubkey=pubkey, + signature=signature, + latest_coin_id=latest_coin_id, + signing_mode=mode.value, + ) ########################################################################################## # CATs and Trading diff --git a/chia/wallet/wallet_rpc_client.py b/chia/wallet/wallet_rpc_client.py index d884302523ed..632c5e919ca0 100644 --- a/chia/wallet/wallet_rpc_client.py +++ b/chia/wallet/wallet_rpc_client.py @@ -1,6 +1,5 @@ from __future__ import annotations -from collections.abc import Sequence from typing import Any, Optional, Union, cast from chia_rs.sized_bytes import bytes32 @@ -17,7 +16,6 @@ from chia.wallet.trading.offer import Offer from chia.wallet.transaction_record import TransactionRecord from chia.wallet.util.clvm_streamable import json_deserialize_with_clvm_streamable -from chia.wallet.util.query_filter import TransactionTypeFilter from chia.wallet.util.tx_config import CoinSelectionConfig, TXConfig from chia.wallet.wallet_coin_store import GetCoinRecords from chia.wallet.wallet_request_types import ( @@ -37,6 +35,8 @@ CreateOfferForIDsResponse, CreateSignedTransactionsResponse, DeleteKey, + DeleteNotifications, + DeleteUnconfirmedTransactions, DIDCreateBackupFile, DIDCreateBackupFileResponse, DIDFindLostDID, @@ -82,12 +82,17 @@ DLUpdateRootResponse, ExecuteSigningInstructions, ExecuteSigningInstructionsResponse, + ExtendDerivationIndex, + ExtendDerivationIndexResponse, GatherSigningInfo, GatherSigningInfoResponse, GenerateMnemonicResponse, GetCATListResponse, + GetCurrentDerivationIndexResponse, GetHeightInfoResponse, GetLoggedInFingerprintResponse, + GetNextAddress, + GetNextAddressResponse, GetNotifications, GetNotificationsResponse, GetOffersCountResponse, @@ -98,6 +103,8 @@ GetTimestampForHeight, GetTimestampForHeightResponse, GetTransaction, + GetTransactionCount, + GetTransactionCountResponse, GetTransactionMemo, GetTransactionMemoResponse, GetTransactionResponse, @@ -153,6 +160,10 @@ SendTransactionMultiResponse, SendTransactionResponse, SetWalletResyncOnStartup, + SignMessageByAddress, + SignMessageByAddressResponse, + SignMessageByID, + SignMessageByIDResponse, SplitCoins, SplitCoinsResponse, SubmitTransactions, @@ -275,23 +286,13 @@ async def get_transaction(self, request: GetTransaction) -> GetTransactionRespon async def get_transactions(self, request: GetTransactions) -> GetTransactionsResponse: return GetTransactionsResponse.from_json_dict(await self.fetch("get_transactions", request.to_json_dict())) - async def get_transaction_count( - self, wallet_id: int, confirmed: Optional[bool] = None, type_filter: Optional[TransactionTypeFilter] = None - ) -> int: - request: dict[str, Any] = {"wallet_id": wallet_id} - if type_filter is not None: - request["type_filter"] = type_filter.to_json_dict() - if confirmed is not None: - request["confirmed"] = confirmed - res = await self.fetch("get_transaction_count", request) - # TODO: casting due to lack of type checked deserialization - return cast(int, res["count"]) + async def get_transaction_count(self, request: GetTransactionCount) -> GetTransactionCountResponse: + return GetTransactionCountResponse.from_json_dict( + await self.fetch("get_transaction_count", request.to_json_dict()) + ) - async def get_next_address(self, wallet_id: int, new_address: bool) -> str: - request = {"wallet_id": wallet_id, "new_address": new_address} - response = await self.fetch("get_next_address", request) - # TODO: casting due to lack of type checked deserialization - return cast(str, response["address"]) + async def get_next_address(self, request: GetNextAddress) -> GetNextAddressResponse: + return GetNextAddressResponse.from_json_dict(await self.fetch("get_next_address", request.to_json_dict())) async def send_transaction( self, @@ -372,18 +373,16 @@ async def spend_clawback_coins( response = await self.fetch("spend_clawback_coins", request) return response - async def delete_unconfirmed_transactions(self, wallet_id: int) -> None: - await self.fetch("delete_unconfirmed_transactions", {"wallet_id": wallet_id}) + async def delete_unconfirmed_transactions(self, request: DeleteUnconfirmedTransactions) -> None: + await self.fetch("delete_unconfirmed_transactions", request.to_json_dict()) - async def get_current_derivation_index(self) -> str: - response = await self.fetch("get_current_derivation_index", {}) - index = response["index"] - return str(index) + async def get_current_derivation_index(self) -> GetCurrentDerivationIndexResponse: + return GetCurrentDerivationIndexResponse.from_json_dict(await self.fetch("get_current_derivation_index", {})) - async def extend_derivation_index(self, index: int) -> str: - response = await self.fetch("extend_derivation_index", {"index": index}) - updated_index = response["index"] - return str(updated_index) + async def extend_derivation_index(self, request: ExtendDerivationIndex) -> ExtendDerivationIndexResponse: + return ExtendDerivationIndexResponse.from_json_dict( + await self.fetch("extend_derivation_index", request.to_json_dict()) + ) async def get_farmed_amount(self, include_pool_rewards: bool = False) -> dict[str, Any]: return await self.fetch("get_farmed_amount", {"include_pool_rewards": include_pool_rewards}) @@ -1122,14 +1121,8 @@ async def get_notifications(self, request: GetNotifications) -> GetNotifications response = await self.fetch("get_notifications", request.to_json_dict()) return json_deserialize_with_clvm_streamable(response, GetNotificationsResponse) - async def delete_notifications(self, ids: Optional[Sequence[bytes32]] = None) -> bool: - request = {} - if ids is not None: - request["ids"] = [id.hex() for id in ids] - response = await self.fetch("delete_notifications", request) - # TODO: casting due to lack of type checked deserialization - result = cast(bool, response["success"]) - return result + async def delete_notifications(self, request: DeleteNotifications) -> None: + await self.fetch("delete_notifications", request.to_json_dict()) async def send_notification( self, @@ -1155,17 +1148,13 @@ async def send_notification( ) return TransactionRecord.from_json_dict(response["tx"]) - async def sign_message_by_address(self, address: str, message: str) -> tuple[str, str, str]: - response = await self.fetch("sign_message_by_address", {"address": address, "message": message}) - return response["pubkey"], response["signature"], response["signing_mode"] - - async def sign_message_by_id( - self, id: str, message: str, is_hex: bool = False, safe_mode: bool = True - ) -> tuple[str, str, str]: - response = await self.fetch( - "sign_message_by_id", {"id": id, "message": message, "is_hex": is_hex, "safe_mode": safe_mode} + async def sign_message_by_address(self, request: SignMessageByAddress) -> SignMessageByAddressResponse: + return SignMessageByAddressResponse.from_json_dict( + await self.fetch("sign_message_by_address", request.to_json_dict()) ) - return response["pubkey"], response["signature"], response["signing_mode"] + + async def sign_message_by_id(self, request: SignMessageByID) -> SignMessageByIDResponse: + return SignMessageByIDResponse.from_json_dict(await self.fetch("sign_message_by_id", request.to_json_dict())) async def verify_signature(self, request: VerifySignature) -> VerifySignatureResponse: return VerifySignatureResponse.from_json_dict(await self.fetch("verify_signature", {**request.to_json_dict()}))