From 751511c54cd9f319f8c83fe0edebfd0653f875a4 Mon Sep 17 00:00:00 2001 From: stijn Date: Thu, 8 May 2025 11:50:18 +0200 Subject: [PATCH] Fix issue #281 --- kernel_tuner/backends/hip.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/kernel_tuner/backends/hip.py b/kernel_tuner/backends/hip.py index 46d2d50a..f5eb27a6 100644 --- a/kernel_tuner/backends/hip.py +++ b/kernel_tuner/backends/hip.py @@ -19,7 +19,6 @@ "bool": ctypes.c_bool, "int8": ctypes.c_int8, "int16": ctypes.c_int16, - "float16": ctypes.c_int16, "int32": ctypes.c_int32, "int64": ctypes.c_int64, "uint8": ctypes.c_uint8, @@ -120,25 +119,29 @@ def ready_argument_list(self, arguments): # Handle numpy arrays if isinstance(arg, np.ndarray): - if dtype_str in dtype_map.keys(): - # Allocate device memory - device_ptr = hip_check(hip.hipMalloc(arg.nbytes)) + # Allocate device memory + device_ptr = hip_check(hip.hipMalloc(arg.nbytes)) - # Copy data to device using hipMemcpy - hip_check(hip.hipMemcpy(device_ptr, arg, arg.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)) + # Copy data to device using hipMemcpy + hip_check(hip.hipMemcpy(device_ptr, arg, arg.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)) - prepared_args.append(device_ptr) - else: - raise TypeError(f"Unknown dtype {dtype_str} for ndarray") + prepared_args.append(device_ptr) # Handle numpy scalar types elif isinstance(arg, np.generic): # Convert numpy scalar to corresponding ctypes - ctype_arg = dtype_map[dtype_str](arg) - prepared_args.append(ctype_arg) + if dtype_str in dtype_map: + ctype_arg = dtype_map[dtype_str](arg) + prepared_args.append(ctype_arg) + # 16-bit float is not supported, view it as uint16 + elif dtype_str in ("float16", "bfloat16"): + ctype_arg = ctypes.c_uint16(arg.view(np.uint16)) + prepared_args.append(ctype_arg) + else: + raise ValueError(f"Invalid argument type {dtype_str}: {arg}") else: - raise ValueError(f"Invalid argument type {type(arg)}, {arg}") + raise ValueError(f"Invalid argument type {type(arg)}: {arg}") return prepared_args