@@ -70,6 +70,11 @@ def is_float_datatype(data_type: DataType):
70
70
def is_numeric_datatype (data_type : DataType ):
71
71
return is_float_datatype (data_type ) or is_integer_datatype (data_type )
72
72
73
+ def is_varchar_datatype (data_type : DataType ):
74
+ return data_type in (DataType .VARCHAR ,)
75
+
76
+ def is_bool_datatype (data_type : DataType ):
77
+ return data_type in (DataType .BOOL ,)
73
78
74
79
# pylint: disable=too-many-return-statements
75
80
def infer_dtype_by_scalar_data (data : Any ):
@@ -105,7 +110,7 @@ def infer_dtype_by_scalar_data(data: Any):
105
110
return DataType .UNKNOWN
106
111
107
112
108
- def infer_dtype_bydata (data : Any ):
113
+ def infer_dtype_bydata (data : Any , ** kargs ):
109
114
d_type = DataType .UNKNOWN
110
115
if is_scalar (data ):
111
116
return infer_dtype_by_scalar_data (data )
@@ -121,7 +126,17 @@ def infer_dtype_bydata(data: Any):
121
126
failed = True
122
127
if not failed :
123
128
d_type = dtype_str_map .get (type_str , DataType .UNKNOWN )
124
- return DataType .FLOAT_VECTOR if is_numeric_datatype (d_type ) else DataType .ARRAY
129
+ if is_varchar_datatype (d_type ) or is_bool_datatype (d_type ):
130
+ return DataType .ARRAY
131
+ if kargs is None or len (kargs ) == 0 :
132
+ return DataType .FLOAT_VECTOR if \
133
+ is_numeric_datatype (d_type ) else DataType .UNKNOWN
134
+ else :
135
+ if kargs ["type" ] is not None and kargs ["type" ] == "vector" :
136
+ return DataType .FLOAT_VECTOR \
137
+ if is_numeric_datatype (d_type ) else DataType .UNKNOWN
138
+ else :
139
+ return DataType .ARRAY
125
140
126
141
if d_type == DataType .UNKNOWN :
127
142
try :
0 commit comments