Skip to content

Commit 11a8f6c

Browse files
committed
Change dtype check in default verify function, fixes issue 245
1 parent 10499af commit 11a8f6c

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

kernel_tuner/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -815,9 +815,9 @@ def _default_verify_function(instance, answer, result_host, atol, verbose):
815815
if isinstance(answer[i], (np.ndarray, cp.ndarray)) and isinstance(
816816
arg, (np.ndarray, cp.ndarray)
817817
):
818-
if answer[i].dtype != arg.dtype:
818+
if not np.can_cast(arg.dtype, answer[i].dtype):
819819
raise TypeError(
820-
f"Element {i} of the expected results list is not of the same dtype as the kernel output: "
820+
f"Element {i} of the expected results list has a dtype that is not compatible with the dtype of the kernel output: "
821821
+ str(answer[i].dtype)
822822
+ " != "
823823
+ str(arg.dtype)

0 commit comments

Comments
 (0)