Skip to content

Commit 6816ffa

Browse files
authored
[AL-4866] Upsert data rows to model run using global keys
2 parents 113ae68 + aa6716d commit 6816ffa

File tree

4 files changed

+38
-13
lines changed

4 files changed

+38
-13
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# Changelog
22

3+
# Version 3.40.0 (YYYY-MM-DD)
4+
5+
## Added
6+
* Upsert data rows to model runs using global keys
7+
38
# Version 3.39.0 (2023-02-28)
49
## Added
510
* New method `Project.task_queues()` to obtain the task queues for a project.

labelbox/schema/model_run.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -80,27 +80,30 @@ def upsert_labels(self, label_ids, timeout_seconds=3600):
8080
}})['MEALabelRegistrationTaskStatus'],
8181
timeout_seconds=timeout_seconds)
8282

83-
def upsert_data_rows(self, data_row_ids, timeout_seconds=3600):
83+
def upsert_data_rows(self,
84+
data_row_ids=None,
85+
global_keys=None,
86+
timeout_seconds=3600):
8487
""" Adds data rows to a Model Run without any associated labels
8588
Args:
86-
data_row_ids (list): data row ids to add to mea
89+
data_row_ids (list): data row ids to add to model run
90+
global_keys (list): global keys for data rows to add to model run
8791
timeout_seconds (float): Max waiting time, in seconds.
8892
Returns:
8993
ID of newly generated async task
9094
"""
9195

92-
if len(data_row_ids) < 1:
93-
raise ValueError("Must provide at least one data row id")
94-
9596
mutation_name = 'createMEAModelRunDataRowRegistrationTask'
96-
create_task_query_str = """mutation createMEAModelRunDataRowRegistrationTaskPyApi($modelRunId: ID!, $dataRowIds : [ID!]!) {
97-
%s(where : { id : $modelRunId}, data : {dataRowIds: $dataRowIds})}
97+
create_task_query_str = """mutation createMEAModelRunDataRowRegistrationTaskPyApi($modelRunId: ID!, $dataRowIds: [ID!], $globalKeys: [ID!]) {
98+
%s(where : { id : $modelRunId}, data : {dataRowIds: $dataRowIds, globalKeys: $globalKeys})}
9899
""" % (mutation_name)
99100

100-
res = self.client.execute(create_task_query_str, {
101-
'modelRunId': self.uid,
102-
'dataRowIds': data_row_ids
103-
})
101+
res = self.client.execute(
102+
create_task_query_str, {
103+
'modelRunId': self.uid,
104+
'dataRowIds': data_row_ids,
105+
'globalKeys': global_keys
106+
})
104107
task_id = res[mutation_name]
105108

106109
status_query_str = """query MEADataRowRegistrationTaskStatusPyApi($where: WhereUniqueIdInput!){

tests/integration/annotation_import/test_model_run.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
import json
21
import time
32
import os
43
import pytest
54

65
from collections import Counter
76

8-
import requests
97
from labelbox import DataSplit, ModelRun
108

119

@@ -99,6 +97,14 @@ def test_model_run_upsert_data_rows(dataset, model_run):
9997
assert n_model_run_data_rows == 1
10098

10199

100+
def test_model_run_upsert_data_rows_using_global_keys(model_run, data_rows):
101+
global_keys = [dr.global_key for dr in data_rows]
102+
assert model_run.upsert_data_rows(global_keys=global_keys)
103+
model_run_data_rows = list(model_run.model_run_data_rows())
104+
added_data_rows = [mdr.data_row() for mdr in model_run_data_rows]
105+
assert set(added_data_rows) == set(data_rows)
106+
107+
102108
def test_model_run_upsert_data_rows_with_existing_labels(
103109
model_run_with_model_run_data_rows):
104110
model_run_data_rows = list(

tests/integration/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,17 @@ def datarow(dataset, image_url):
212212
dr.delete()
213213

214214

215+
@pytest.fixture()
216+
def data_rows(dataset, image_url):
217+
dr1 = dataset.create_data_row(row_data=image_url,
218+
global_key=f"global-key-{uuid.uuid4()}")
219+
dr2 = dataset.create_data_row(row_data=image_url,
220+
global_key=f"global-key-{uuid.uuid4()}")
221+
yield [dr1, dr2]
222+
dr1.delete()
223+
dr2.delete()
224+
225+
215226
@pytest.fixture
216227
def iframe_url(environ) -> str:
217228
if environ in [Environ.PROD, Environ.LOCAL]:

0 commit comments

Comments
 (0)