Skip to content

Commit e4990fa

Browse files
committed
SRC/API/PYTHON: fix pylint warnings in src, examples and tests
Signed-off-by: Roie Danino <[email protected]>
1 parent 91bd79f commit e4990fa

13 files changed

+233
-173
lines changed

examples/python/__init__.py

Whitespace-only changes.

examples/python/blocking_send_recv_example.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,20 @@
1616
# limitations under the License.
1717

1818
import argparse
19+
import enum
20+
import sys
1921

2022
import torch
2123

2224
from nixl._api import nixl_agent, nixl_agent_config
2325

2426

27+
class BlockingSendRecvErrCodes(enum.Enum):
28+
MEM_REG_FAILED = 1
29+
TRANSFER_FAILED = 2
30+
DATA_VERIFICATION_FAILED = 3
31+
32+
2533
def parse_args():
2634
parser = argparse.ArgumentParser()
2735
parser.add_argument("--ip", type=str, required=True)
@@ -63,7 +71,7 @@ def parse_args():
6371
reg_descs = agent.register_memory(tensors)
6472
if not reg_descs: # Same as reg_descs if successful
6573
print("Memory registration failed.")
66-
exit()
74+
sys.exit(BlockingSendRecvErrCodes.MEM_REG_FAILED.value)
6775

6876
# Target code
6977
if args.mode == "target":
@@ -113,25 +121,25 @@ def parse_args():
113121

114122
if not xfer_handle:
115123
print("Creating transfer failed.")
116-
exit()
124+
sys.exit(BlockingSendRecvErrCodes.TRANSFER_FAILED.value)
117125

118126
state = agent.transfer(xfer_handle)
119127
if state == "ERR":
120128
print("Posting transfer failed.")
121-
exit()
129+
sys.exit(BlockingSendRecvErrCodes.TRANSFER_FAILED.value)
122130
while True:
123131
state = agent.check_xfer_state(xfer_handle)
124132
if state == "ERR":
125133
print("Transfer got to Error state.")
126-
exit()
134+
sys.exit(BlockingSendRecvErrCodes.TRANSFER_FAILED.value)
127135
elif state == "DONE":
128136
break
129137

130138
# Verify data after read
131139
for i, tensor in enumerate(tensors):
132140
if not torch.allclose(tensor, torch.ones(10)):
133141
print(f"Data verification failed for tensor {i}.")
134-
exit()
142+
sys.exit(BlockingSendRecvErrCodes.DATA_VERIFICATION_FAILED.value)
135143
print(f"{args.mode} Data verification passed - {tensors}")
136144

137145
if args.mode != "target":

examples/python/nixl_api_example.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,25 @@
1616
# limitations under the License.
1717

1818
import os
19+
import sys
20+
import enum
1921

2022
import numpy as np
2123
import torch
2224

2325
import nixl._utils as nixl_utils
2426
from nixl._api import nixl_agent, nixl_agent_config
2527

28+
from . import util
29+
30+
31+
class NixlApiExampleErrCodes(enum.Enum):
32+
CREATE_TRANSFER_FAILED = 1
33+
PREP_TRANSFER_SIDE_HANDLES_FAILED = 2
34+
MAKE_PREPPED_TRANSFER_FAILED = 3
35+
TRANSFER_FAILED = 4
36+
37+
2638
if __name__ == "__main__":
2739
buf_size = 256
2840
# Allocate memory and register with NIXL
@@ -104,30 +116,13 @@
104116
)
105117
if not xfer_handle_1:
106118
print("Creating transfer failed.")
107-
exit()
119+
sys.exit(NixlApiExampleErrCodes.CREATE_TRANSFER_FAILED.value)
108120

109121
# test multiple postings
110122
for _ in range(2):
111123
state = nixl_agent2.transfer(xfer_handle_1)
112124
assert state != "ERR"
113-
114-
target_done = False
115-
init_done = False
116-
117-
while (not init_done) or (not target_done):
118-
if not init_done:
119-
state = nixl_agent2.check_xfer_state(xfer_handle_1)
120-
if state == "ERR":
121-
print("Transfer got to Error state.")
122-
exit()
123-
elif state == "DONE":
124-
init_done = True
125-
print("Initiator done")
126-
127-
if not target_done:
128-
if nixl_agent1.check_remote_xfer_done("initiator", b"UUID1"):
129-
target_done = True
130-
print("Target done")
125+
util.wait_for_transfer_completion(nixl_agent2, nixl_agent1, xfer_handle_1, b"UUID1")
131126

132127
# prep transfer mode
133128
local_prep_handle = nixl_agent2.prep_xfer_dlist(
@@ -165,11 +160,11 @@
165160
)
166161
if not local_prep_handle or not remote_prep_handle:
167162
print("Preparing transfer side handles failed.")
168-
exit()
163+
sys.exit(NixlApiExampleErrCodes.PREP_TRANSFER_SIDE_HANDLES_FAILED.value)
169164

170165
if not xfer_handle_2:
171166
print("Make prepped transfer failed.")
172-
exit()
167+
sys.exit(NixlApiExampleErrCodes.MAKE_PREPPED_TRANSFER_FAILED.value)
173168

174169
state = nixl_agent2.transfer(xfer_handle_2)
175170
assert state != "ERR"
@@ -184,7 +179,7 @@
184179
state = nixl_agent2.check_xfer_state(xfer_handle_2)
185180
if state == "ERR":
186181
print("Transfer got to Error state.")
187-
exit()
182+
sys.exit(NixlApiExampleErrCodes.TRANSFER_FAILED.value)
188183
elif state == "DONE":
189184
init_done = True
190185
print("Initiator done")

examples/python/nixl_gds_example.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,27 @@
1717

1818
import os
1919
import sys
20+
import enum
2021

2122
import nixl._utils as nixl_utils
2223
from nixl._api import nixl_agent, nixl_agent_config
2324

25+
26+
class NixlGdsExampleErrCodes(enum.Enum):
27+
MISSING_FILE_PATH = 1
28+
CREATE_TRANSFER_FAILED = 2
29+
TRANSFER_FAILED = 3
30+
INIT_XFER_FAILED = 4
31+
DATA_VERIFICATION_FAILED = 5
32+
33+
2434
if __name__ == "__main__":
2535
buf_size = 16 * 4096
2636
# Allocate memory and register with NIXL
2737

2838
if len(sys.argv) < 2:
2939
print("Please specify file path in argv")
30-
exit(0)
40+
sys.exit(NixlGdsExampleErrCodes.MISSING_FILE_PATH.value)
3141

3242
print("Using NIXL Plugins from:")
3343
print(os.environ["NIXL_PLUGIN_DIR"])
@@ -79,7 +89,7 @@
7989
)
8090
if not xfer_handle_1:
8191
print("Creating transfer failed.")
82-
exit()
92+
sys.exit(NixlGdsExampleErrCodes.CREATE_TRANSFER_FAILED.value)
8393

8494
state = nixl_agent1.transfer(xfer_handle_1)
8595
assert state != "ERR"
@@ -90,7 +100,7 @@
90100
state = nixl_agent1.check_xfer_state(xfer_handle_1)
91101
if state == "ERR":
92102
print("Transfer got to Error state.")
93-
exit()
103+
sys.exit(NixlGdsExampleErrCodes.TRANSFER_FAILED.value)
94104
elif state == "DONE":
95105
done = True
96106
print("Initiator done")
@@ -101,7 +111,7 @@
101111
)
102112
if not xfer_handle_2:
103113
print("Creating transfer failed.")
104-
exit()
114+
sys.exit(NixlGdsExampleErrCodes.INIT_XFER_FAILED.value)
105115

106116
state = nixl_agent1.transfer(xfer_handle_2)
107117
assert state != "ERR"
@@ -112,7 +122,7 @@
112122
state = nixl_agent1.check_xfer_state(xfer_handle_2)
113123
if state == "ERR":
114124
print("Transfer got to Error state.")
115-
exit()
125+
sys.exit(NixlGdsExampleErrCodes.TRANSFER_FAILED.value)
116126
elif state == "DONE":
117127
done = True
118128
print("Initiator done")

examples/python/partial_md_example.py

Lines changed: 28 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,18 @@
1616
# limitations under the License.
1717

1818
import argparse
19+
import enum
1920
import os
2021

2122
import nixl._utils as nixl_utils
2223
from nixl._api import nixl_agent, nixl_agent_config
2324
from nixl._bindings import nixlNotFoundError
25+
from . import util
2426

2527

28+
class PartialMdExampleErrCodes(enum.Enum):
29+
TRANSFER_FAILED = 1
30+
2631
def exchange_target_metadata(
2732
target_agent,
2833
init_agent,
@@ -62,8 +67,21 @@ def invalidate_target_metadata(
6267
else:
6368
target_agent.invalidate_local_metadata(ip_addr, init_port)
6469

70+
def allocate_strings(malloc_addrs: list, buf_size: int, num_strings: int):
71+
target_strs = []
72+
for _ in range(num_strings):
73+
addr1 = nixl_utils.malloc_passthru(buf_size)
74+
target_strs.append((addr1, buf_size, 0, "test"))
75+
malloc_addrs.append(addr1)
6576

66-
if __name__ == "__main__":
77+
return target_strs
78+
79+
def xfer(init_agent, xfer_handle, target_agent):
80+
state = init_agent.transfer(xfer_handle)
81+
assert state != "ERR"
82+
util.wait_for_transfer_completion(init_agent, target_agent, xfer_handle, b"UUID1")
83+
84+
def main():
6785
buf_size = 256
6886
# Allocate memory and register with NIXL
6987

@@ -102,17 +120,8 @@ def invalidate_target_metadata(
102120

103121
malloc_addrs = []
104122

105-
target_strs1 = []
106-
for _ in range(10):
107-
addr1 = nixl_utils.malloc_passthru(buf_size)
108-
target_strs1.append((addr1, buf_size, 0, "test"))
109-
malloc_addrs.append(addr1)
110-
111-
target_strs2 = []
112-
for _ in range(10):
113-
addr1 = nixl_utils.malloc_passthru(buf_size)
114-
target_strs2.append((addr1, buf_size, 0, "test"))
115-
malloc_addrs.append(addr1)
123+
target_strs1 = allocate_strings(malloc_addrs, buf_size, 10)
124+
target_strs2 = allocate_strings(malloc_addrs, buf_size, 10)
116125

117126
target_reg_descs1 = target_agent.get_reg_descs(target_strs1, "DRAM", is_sorted=True)
118127
target_reg_descs2 = target_agent.get_reg_descs(target_strs2, "DRAM", is_sorted=True)
@@ -126,11 +135,7 @@ def invalidate_target_metadata(
126135
agent_config2 = nixl_agent_config(True, True, init_port)
127136
init_agent = nixl_agent("initiator", agent_config2)
128137

129-
init_strs = []
130-
for _ in range(10):
131-
addr1 = nixl_utils.malloc_passthru(buf_size)
132-
init_strs.append((addr1, buf_size, 0, "test"))
133-
malloc_addrs.append(addr1)
138+
init_strs = allocate_strings(malloc_addrs, buf_size, 10)
134139

135140
init_reg_descs = init_agent.get_reg_descs(init_strs, "DRAM", is_sorted=True)
136141
init_xfer_descs = init_reg_descs.trim()
@@ -157,27 +162,7 @@ def invalidate_target_metadata(
157162
"READ", init_xfer_descs, target_xfer_descs1, "target", b"UUID1"
158163
)
159164

160-
state = init_agent.transfer(xfer_handle_1)
161-
assert state != "ERR"
162-
163-
target_done = False
164-
init_done = False
165-
166-
while (not init_done) or (not target_done):
167-
if not init_done:
168-
state = init_agent.check_xfer_state(xfer_handle_1)
169-
if state == "ERR":
170-
print("Transfer got to Error state.")
171-
exit()
172-
elif state == "DONE":
173-
init_done = True
174-
print("Initiator done")
175-
176-
if not target_done:
177-
if target_agent.check_remote_xfer_done("initiator", b"UUID1"):
178-
target_done = True
179-
print("Target done")
180-
165+
xfer(init_agent, xfer_handle_1, target_agent)
181166
# Second set of descs was not sent, should fail
182167
try:
183168
xfer_handle_2 = init_agent.initialize_xfer(
@@ -214,26 +199,7 @@ def invalidate_target_metadata(
214199
else:
215200
ready = True
216201

217-
state = init_agent.transfer(xfer_handle_2)
218-
assert state != "ERR"
219-
220-
target_done = False
221-
init_done = False
222-
223-
while (not init_done) or (not target_done):
224-
if not init_done:
225-
state = init_agent.check_xfer_state(xfer_handle_2)
226-
if state == "ERR":
227-
print("Transfer got to Error state.")
228-
exit()
229-
elif state == "DONE":
230-
init_done = True
231-
print("Initiator done")
232-
233-
if not target_done:
234-
if target_agent.check_remote_xfer_done("initiator", b"UUID1"):
235-
target_done = True
236-
print("Target done")
202+
xfer(init_agent, xfer_handle_2, target_agent)
237203

238204
init_agent.release_xfer_handle(xfer_handle_1)
239205
init_agent.release_xfer_handle(xfer_handle_2)
@@ -250,3 +216,7 @@ def invalidate_target_metadata(
250216
del target_agent
251217

252218
print("Test Complete.")
219+
220+
221+
if __name__ == "__main__":
222+
main()

examples/python/query_mem_example.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
1717

18+
# pylint: disable=broad-exception-caught
19+
1820
import os
1921
import sys
2022
import tempfile
23+
import traceback
2124

2225
try:
2326
from nixl._api import nixl_agent, nixl_agent_config
@@ -27,7 +30,8 @@
2730
print("NIXL API missing install NIXL.")
2831
NIXL_AVAILABLE = False
2932

30-
if __name__ == "__main__":
33+
34+
def main():
3135
print("NIXL queryMem Python API Example")
3236
print("=" * 40)
3337

@@ -111,8 +115,6 @@
111115

112116
except Exception as e:
113117
print(f"Error in example: {e}")
114-
import traceback
115-
116118
traceback.print_exc()
117119

118120
finally:
@@ -122,3 +124,7 @@
122124
if os.path.exists(temp_file_path):
123125
os.unlink(temp_file_path)
124126
print(f"Removed: {temp_file_path}")
127+
128+
129+
if __name__ == "__main__":
130+
main()

0 commit comments

Comments
 (0)