Skip to content

Commit 9b8f64e

Browse files
Add PredictionModel and Prediction
1 parent 3c4f91d commit 9b8f64e

File tree

8 files changed

+132
-0
lines changed

8 files changed

+132
-0
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# Changelog
22

3+
## Not yet released
4+
5+
### Added
6+
* `Prediction` and `PredictionModel` data types.
7+
38
## Version 2.3 (2019-11-12)
49

510
### Changed

labelbox/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@
1212
from labelbox.schema.labeling_frontend import LabelingFrontend
1313
from labelbox.schema.asset_metadata import AssetMetadata
1414
from labelbox.schema.webhook import Webhook
15+
from labelbox.schema.prediction import Prediction, PredictionModel

labelbox/schema/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@
1010
import labelbox.schema.task
1111
import labelbox.schema.user
1212
import labelbox.schema.webhook
13+
import labelbox.schema.prediction

labelbox/schema/data_row.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class DataRow(DbObject, Updateable, BulkDeletable):
1919
organization = Relationship.ToOne("Organization", False)
2020
labels = Relationship.ToMany("Label", True)
2121
metadata = Relationship.ToMany("AssetMetadata", False, "metadata")
22+
predictions = Relationship.ToMany("Prediction", False)
2223

2324
@staticmethod
2425
def bulk_delete(data_rows):

labelbox/schema/prediction.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from labelbox.orm.db_object import DbObject
2+
from labelbox.orm.model import Field, Relationship
3+
4+
5+
class PredictionModel(DbObject):
6+
""" A prediction model represents a specific version of a model. """
7+
updated_at = Field.DateTime("updated_at")
8+
created_at = Field.DateTime("created_at")
9+
created_by = Relationship.ToOne("User", False, "created_by")
10+
organization = Relationship.ToOne("Organization", False)
11+
12+
name = Field.String("name")
13+
slug = Field.String("slug")
14+
version = Field.Int("version")
15+
16+
created_predictions = Relationship.ToMany("Prediction", False,
17+
"created_predictions")
18+
19+
20+
class Prediction(DbObject):
21+
""" A prediction created by a PredictionModel. """
22+
updated_at = Field.DateTime("updated_at")
23+
created_at = Field.DateTime("created_at")
24+
organization = Relationship.ToOne("Organization", False)
25+
26+
label = Field.String("label")
27+
agreement = Field.Float("agreement")
28+
29+
prediction_model = Relationship.ToOne("PredictionModel", False)
30+
data_row = Relationship.ToOne("DataRow", False)
31+
project = Relationship.ToOne("Project", False)

labelbox/schema/project.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ class Project(DbObject, Updateable, Deletable):
4040
"LabelingParameterOverride", False, "labeling_parameter_overrides")
4141
webhooks = Relationship.ToMany("Webhook", False)
4242
benchmarks = Relationship.ToMany("Benchmark", False)
43+
active_prediction_model = Relationship.ToOne("PredictionModel", False,
44+
"active_prediction_model")
45+
predictions = Relationship.ToMany("Prediction", False)
4346

4447
def create_label(self, **kwargs):
4548
""" Creates a label on this Project.
@@ -283,6 +286,59 @@ def extend_reservations(self, queue_type):
283286
res = self.client.execute(query_str, {id_param: self.uid})
284287
return res["extendReservations"]
285288

289+
def create_prediction_model(self, name, version):
290+
""" Creates a PredictionModel connected to this Project.
291+
Args:
292+
name (str): The new PredictionModel's name.
293+
version (int): The new PredictionModel's version.
294+
Return:
295+
A newly created PredictionModel.
296+
"""
297+
PM = Entity.PredictionModel
298+
model = self.client._create(
299+
PM, {PM.name.name: name, PM.version.name: version})
300+
self.active_prediction_model.connect(model)
301+
return model
302+
303+
def create_prediction(self, label, data_row, prediction_model=None):
304+
""" Creates a Prediction within this Project.
305+
Args:
306+
label (str): The `label` field of the new Prediction.
307+
data_row (DataRow): The DataRow for which the Prediction is created.
308+
prediction_model (PredictionModel or None): The PredictionModel
309+
within which the new Prediction is created. If None then this
310+
Project's active_prediction_model is used.
311+
Return:
312+
A newly created Prediction.
313+
Raises:
314+
labelbox.excepions.InvalidQueryError: if given `prediction_model`
315+
is None and this Project's active_prediction_model is also
316+
None.
317+
"""
318+
if prediction_model is None:
319+
prediction_model = self.active_prediction_model()
320+
if prediction_model is None:
321+
raise InvalidQueryError(
322+
"Project '%s' has no active prediction model" % self.name)
323+
324+
label_param = "label"
325+
model_param = "prediction_model_id"
326+
project_param = "project_id"
327+
data_row_param = "data_row_id"
328+
329+
Prediction = Entity.Prediction
330+
query_str = """mutation CreatePredictionPyApi(
331+
$%s: String!, $%s: ID!, $%s: ID!, $%s: ID!) {createPrediction(
332+
data: {label: $%s, predictionModelId: $%s, projectId: $%s,
333+
dataRowId: $%s})
334+
{%s}}""" % (label_param, model_param, project_param, data_row_param,
335+
label_param, model_param, project_param, data_row_param,
336+
query.results_query_part(Prediction))
337+
params = {label_param: label, model_param: prediction_model.uid,
338+
data_row_param: data_row.uid, project_param: self.uid}
339+
res = self.client.execute(query_str, params)
340+
return Prediction(self.client, res["createPrediction"])
341+
286342

287343
class LabelingParameterOverride(DbObject):
288344
priority = Field.Int("priority")

tests/integration/test_predictions.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
def test_prediction_model(project, rand_gen):
2+
model_name = rand_gen(str)
3+
version = 42
4+
model_1 = project.create_prediction_model(model_name, version)
5+
assert model_1.name == model_name
6+
assert model_1.version == version
7+
assert project.active_prediction_model() == model_1
8+
9+
model_2 = project.create_prediction_model(rand_gen(str), 22)
10+
assert project.active_prediction_model() == model_2
11+
12+
13+
def test_predictions(label_pack, rand_gen):
14+
project, dataset, data_row, label = label_pack
15+
model_1 = project.create_prediction_model(rand_gen(str), 12)
16+
17+
assert set(project.predictions()) == set()
18+
pred_1 = project.create_prediction("l1", data_row)
19+
assert pred_1.label == "l1"
20+
assert set(model_1.created_predictions()) == {pred_1}
21+
assert set(project.predictions()) == {pred_1}
22+
assert set(data_row.predictions()) == {pred_1}
23+
assert pred_1.prediction_model() == model_1
24+
assert pred_1.data_row() == data_row
25+
assert pred_1.project() == project
26+
label_2 = project.create_label(data_row=data_row, label="test",
27+
seconds_to_label=0.0)
28+
29+
model_2 = project.create_prediction_model(rand_gen(str), 12)
30+
assert set(project.predictions()) == {pred_1}
31+
pred_2 = project.create_prediction("l2", data_row)
32+
assert pred_2.label == "l2"
33+
assert set(model_1.created_predictions()) == {pred_1}
34+
assert set(model_2.created_predictions()) == {pred_2}
35+
assert set(project.predictions()) == {pred_1, pred_2}
36+
assert set(data_row.predictions()) == {pred_1, pred_2}

tools/db_object_doc_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
labelbox.Project, labelbox.Dataset, labelbox.DataRow, labelbox.Label,
4545
labelbox.AssetMetadata, labelbox.LabelingFrontend, labelbox.Task,
4646
labelbox.Webhook, labelbox.User, labelbox.Organization, labelbox.Review,
47+
labelbox.Prediction, labelbox.PredictionModel,
4748
LabelerPerformance]
4849

4950
ERROR_CLASSES = [LabelboxError] + LabelboxError.__subclasses__()

0 commit comments

Comments
 (0)