Skip to content

Commit 2432346

Browse files
author
Val Brodsky
committed
Fix search filters
1 parent 18546d3 commit 2432346

File tree

2 files changed

+58
-51
lines changed

2 files changed

+58
-51
lines changed

libs/labelbox/src/labelbox/schema/search_filters.py

Lines changed: 50 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,24 @@
11
import datetime
22
from enum import Enum
3-
from typing import List, Literal, Union
3+
from typing import List, Union
4+
from pydantic import PlainSerializer, BaseModel, Field
5+
6+
from typing_extensions import Annotated
47

5-
from pydantic import BaseModel, Field
68
from labelbox.schema.labeling_service_status import LabelingServiceStatus
79
from labelbox.utils import format_iso_datetime
8-
from pydantic.config import ConfigDict
910

1011

1112
class BaseSearchFilter(BaseModel):
1213
"""
1314
Shared code for all search filters
1415
"""
1516

16-
model_config = ConfigDict(use_enum_values=True)
17-
18-
def dict(self, *args, **kwargs):
19-
res = super().dict(*args, **kwargs)
20-
# go through all the keys and convert date to string
21-
for key in res:
22-
if isinstance(res[key], datetime.datetime):
23-
res[key] = format_iso_datetime(res[key])
24-
return res
17+
class Config:
18+
use_enum_values = True
2519

2620

27-
class OperationType(Enum):
21+
class OperationTypeEnum(Enum):
2822
"""
2923
Supported search entity types
3024
Each type corresponds to a different filter class
@@ -40,6 +34,13 @@ class OperationType(Enum):
4034
TaskRemainingCount = 'task_remaining_count'
4135

4236

37+
OperationType = Annotated[OperationTypeEnum,
38+
PlainSerializer(lambda x: x.value, return_type=str)]
39+
40+
IsoDatetimeType = Annotated[datetime.datetime,
41+
PlainSerializer(format_iso_datetime)]
42+
43+
4344
class IdOperator(Enum):
4445
"""
4546
Supported operators for ids like org ids, workspace ids, etc
@@ -75,7 +76,8 @@ class OrganizationFilter(BaseSearchFilter):
7576
"""
7677
Filter for organization to which projects belong
7778
"""
78-
operation: Literal[OperationType.Organization] = OperationType.Organization
79+
operation: OperationType = Field(default=OperationTypeEnum.Organization,
80+
serialization_alias='type')
7981
operator: IdOperator
8082
values: List[str]
8183

@@ -84,9 +86,10 @@ class SharedWithOrganizationFilter(BaseSearchFilter):
8486
"""
8587
Find project shared with the organization (i.e. not having this organization as a tenantId)
8688
"""
87-
operation: Literal[
88-
OperationType.
89-
SharedWithOrganization] = OperationType.SharedWithOrganization
89+
90+
operation: OperationType = Field(
91+
default=OperationTypeEnum.SharedWithOrganization,
92+
serialization_alias='type')
9093
operator: IdOperator
9194
values: List[str]
9295

@@ -95,7 +98,8 @@ class WorkspaceFilter(BaseSearchFilter):
9598
"""
9699
Filter for workspace
97100
"""
98-
operation: Literal[OperationType.Workspace] = OperationType.Workspace
101+
operation: OperationType = Field(default=OperationTypeEnum.Workspace,
102+
serialization_alias='type')
99103
operator: IdOperator
100104
values: List[str]
101105

@@ -104,7 +108,8 @@ class TagFilter(BaseSearchFilter):
104108
"""
105109
Filter for project tags
106110
"""
107-
operation: Literal[OperationType.Tag] = OperationType.Tag
111+
operation: OperationType = Field(default=OperationTypeEnum.Tag,
112+
serialization_alias='type')
108113
operator: IdOperator
109114
values: List[str]
110115

@@ -114,7 +119,8 @@ class ProjectStageFilter(BaseSearchFilter):
114119
Filter labelbox service / aka project stages
115120
Stages are: requested, in_progress, completed etc. as described by LabelingServiceStatus
116121
"""
117-
operation: Literal[OperationType.Stage] = OperationType.Stage
122+
operation: OperationType = Field(default=OperationTypeEnum.Stage,
123+
serialization_alias='type')
118124
operator: IdOperator
119125
values: List[LabelingServiceStatus]
120126

@@ -132,7 +138,7 @@ class DateValue(BaseSearchFilter):
132138
while the same string in EST will get converted to '2024-01-01T05:00:00Z'
133139
"""
134140
operator: RangeDateTimeOperatorWithSingleValue
135-
value: datetime.datetime
141+
value: IsoDatetimeType
136142

137143

138144
class IntegerValue(BaseSearchFilter):
@@ -144,28 +150,28 @@ class WorkforceStageUpdatedFilter(BaseSearchFilter):
144150
"""
145151
Filter for workforce stage updated date
146152
"""
147-
operation: Literal[
148-
OperationType.
149-
WorkforceStageUpdatedDate] = OperationType.WorkforceStageUpdatedDate
153+
operation: OperationType = Field(
154+
default=OperationTypeEnum.WorkforceStageUpdatedDate,
155+
serialization_alias='type')
150156
value: DateValue
151157

152158

153159
class WorkforceRequestedDateFilter(BaseSearchFilter):
154160
"""
155161
Filter for workforce requested date
156162
"""
157-
operation: Literal[
158-
OperationType.
159-
WorforceRequestedDate] = OperationType.WorforceRequestedDate
163+
operation: OperationType = Field(
164+
default=OperationTypeEnum.WorforceRequestedDate,
165+
serialization_alias='type')
160166
value: DateValue
161167

162168

163169
class DateRange(BaseSearchFilter):
164170
"""
165171
Date range for a search filter
166172
"""
167-
min: datetime.datetime
168-
max: datetime.datetime
173+
min: IsoDatetimeType
174+
max: IsoDatetimeType
169175

170176

171177
class DateRangeValue(BaseSearchFilter):
@@ -180,19 +186,19 @@ class WorkforceRequestedDateRangeFilter(BaseSearchFilter):
180186
"""
181187
Filter for workforce requested date range
182188
"""
183-
operation: Literal[
184-
OperationType.
185-
WorforceRequestedDate] = OperationType.WorforceRequestedDate
189+
operation: OperationType = Field(
190+
default=OperationTypeEnum.WorforceRequestedDate,
191+
serialization_alias='type')
186192
value: DateRangeValue
187193

188194

189195
class WorkforceStageUpdatedRangeFilter(BaseSearchFilter):
190196
"""
191197
Filter for workforce stage updated date range
192198
"""
193-
operation: Literal[
194-
OperationType.
195-
WorkforceStageUpdatedDate] = OperationType.WorkforceStageUpdatedDate
199+
operation: OperationType = Field(
200+
default=OperationTypeEnum.WorkforceStageUpdatedDate,
201+
serialization_alias='type')
196202
value: DateRangeValue
197203

198204

@@ -201,20 +207,19 @@ class TaskCompletedCountFilter(BaseSearchFilter):
201207
Filter for completed tasks count
202208
A task maps to a data row. Task completed should map to a data row in a labeling queue DONE
203209
"""
204-
operation: Literal[
205-
OperationType.TaskCompletedCount] = Field(default=OperationType.TaskCompletedCount, serialization_alias='type')
210+
operation: OperationType = Field(
211+
default=OperationTypeEnum.TaskCompletedCount,
212+
serialization_alias='type')
206213
value: IntegerValue
207214

208215

209-
210-
211-
212216
class TaskRemainingCountFilter(BaseSearchFilter):
213217
"""
214218
Filter for remaining tasks count. Reverse of TaskCompletedCountFilter
215219
"""
216-
operation: Literal[
217-
OperationType.TaskRemainingCount] = Field(OperationType.TaskRemainingCount, serialization_alias='type')
220+
operation: OperationType = Field(
221+
default=OperationTypeEnum.TaskRemainingCount,
222+
serialization_alias='type')
218223
value: IntegerValue
219224

220225

@@ -242,5 +247,7 @@ def build_search_filter(filter: List[SearchFilter]):
242247
"""
243248
Converts a list of search filters to a graphql string
244249
"""
245-
filters = [_dict_to_graphql_string(f.model_dump(by_alias=True)) for f in filter]
250+
filters = [
251+
_dict_to_graphql_string(f.model_dump(by_alias=True)) for f in filter
252+
]
246253
return "[" + ", ".join(filters) + "]"

libs/labelbox/tests/unit/test_unit_search_filters.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def test_id_filters():
1919

2020
assert build_search_filter(
2121
filters
22-
) == '[{operator: "is", values: ["clphb4vd7000cd2wv1ktu5cwa"], type: "organization_id"}, {operator: "is", values: ["clphb4vd7000cd2wv1ktu5cwa"], type: "shared_with_organizations"}, {operator: "is", values: ["clphb4vd7000cd2wv1ktu5cwa"], type: "workspace"}, {operator: "is", values: ["tag"], type: "tag"}, {operator: "is", values: ["REQUESTED"], type: "stage"}]'
22+
) == '[{type: "organization_id", operator: "is", values: ["clphb4vd7000cd2wv1ktu5cwa"]}, {type: "shared_with_organizations", operator: "is", values: ["clphb4vd7000cd2wv1ktu5cwa"]}, {type: "workspace", operator: "is", values: ["clphb4vd7000cd2wv1ktu5cwa"]}, {type: "tag", operator: "is", values: ["tag"]}, {type: "stage", operator: "is", values: ["REQUESTED"]}]'
2323

2424

2525
def test_date_filters():
@@ -37,7 +37,7 @@ def test_date_filters():
3737
expected_start = format_iso_datetime(local_time_start)
3838
expected_end = format_iso_datetime(local_time_end)
3939

40-
expected = '[{value: {operator: "GREATER_THAN_OR_EQUAL", value: "' + expected_start + '"}, type: "workforce_requested_at"}, {value: {operator: "LESS_THAN_OR_EQUAL", value: "' + expected_end + '"}, type: "workforce_stage_updated_at"}]'
40+
expected = '[{type: "workforce_requested_at", value: {operator: "GREATER_THAN_OR_EQUAL", value: "' + expected_start + '"}}, {type: "workforce_stage_updated_at", value: {operator: "LESS_THAN_OR_EQUAL", value: "' + expected_end + '"}}]'
4141
assert build_search_filter(filters) == expected
4242

4343

@@ -58,16 +58,16 @@ def test_date_range_filters():
5858
]
5959
assert build_search_filter(
6060
filters
61-
) == '[{value: {operator: "BETWEEN", value: {min: "2024-01-01T08:00:00Z", max: "2025-01-01T08:00:00Z"}}, type: "workforce_requested_at"}, {value: {operator: "BETWEEN", value: {min: "2024-01-01T08:00:00Z", max: "2025-01-01T08:00:00Z"}}, type: "workforce_stage_updated_at"}]'
61+
) == '[{type: "workforce_requested_at", value: {operator: "BETWEEN", value: {min: "2024-01-01T08:00:00Z", max: "2025-01-01T08:00:00Z"}}}, {type: "workforce_stage_updated_at", value: {operator: "BETWEEN", value: {min: "2024-01-01T08:00:00Z", max: "2025-01-01T08:00:00Z"}}}]'
6262

6363

6464
def test_task_count_filters():
6565
filters = [
66-
TaskCompletedCountFilter(value=IntegerValue(operator=RangeOperatorWithSingleValue.GreaterThanOrEqual, value=1)),
67-
# TaskRemainingCountFilter(value=IntegerValue(
68-
# operator=RangeOperatorWithSingleValue.LessThanOrEqual, value=10)),
66+
TaskCompletedCountFilter(value=IntegerValue(
67+
operator=RangeOperatorWithSingleValue.GreaterThanOrEqual, value=1)),
68+
TaskRemainingCountFilter(value=IntegerValue(
69+
operator=RangeOperatorWithSingleValue.LessThanOrEqual, value=10)),
6970
]
7071

71-
# expected = '[{value: {operator: "GREATER_THAN_OR_EQUAL", value: 1}, type: "task_completed_count"}, {value: {operator: "LESS_THAN_OR_EQUAL", value: 10}, type: "task_remaining_count"}]'
72-
expected = '[{value: {operator: "GREATER_THAN_OR_EQUAL", value: 1}, type: "task_completed_count"}, {value: {operator: "LESS_THAN_OR_EQUAL", value: 10}, type: "task_remaining_count"}]'
72+
expected = '[{type: "task_completed_count", value: {operator: "GREATER_THAN_OR_EQUAL", value: 1}}, {type: "task_remaining_count", value: {operator: "LESS_THAN_OR_EQUAL", value: 10}}]'
7373
assert build_search_filter(filters) == expected

0 commit comments

Comments
 (0)