Commit f596d43

Merge pull request #300 from KernelTuner/fix-pycuda-fp16
Add support for half and bfloat16 scalars in pyCUDA backend
2 parents: 10499af + 2a70264

File tree

1 file changed: +3 −0 lines changed


kernel_tuner/backends/pycuda.py

Lines changed: 3 additions & 0 deletions
@@ -180,6 +180,9 @@ def ready_argument_list(self, arguments):
             # pycuda does not support bool, convert to uint8 instead
             elif isinstance(arg, np.bool_):
                 gpu_args.append(arg.astype(np.uint8))
+            # pycuda does not support 16-bit formats, view them as uint16
+            elif isinstance(arg, np.generic) and str(arg.dtype) in ("float16", "bfloat16"):
+                gpu_args.append(arg.view(np.uint16))
             # if not an array, just pass argument along
             else:
                 gpu_args.append(arg)
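For context, the uint16 view works because numpy's .view() reinterprets the scalar's bits without converting the value, so the 16-bit pattern handed to pycuda is exactly the half (or bfloat16) encoding the kernel expects. A minimal sketch of that bit-preserving view, separate from the backend code (the kernel signature mentioned in the comment is only an assumed example):

import numpy as np

# A half-precision scalar argument; pycuda cannot pass np.float16 directly.
x = np.float16(1.5)

# .view() reinterprets the same 16 bits as uint16 with no conversion, so a
# device-side half parameter (e.g. a hypothetical "__global__ void k(__half v)")
# still receives the original value.
x_bits = x.view(np.uint16)

print(hex(int(x_bits)))          # 0x3e00, the IEEE 754 binary16 encoding of 1.5
print(x_bits.view(np.float16))   # 1.5, viewing back recovers the value

An astype() conversion, by contrast, would round 1.5 to the integer 2 and lose the encoding, which is why the patch uses view() rather than the astype() path taken for bool arguments.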
