Skip to content

Commit 54ab49c

Browse files
feat(LAB-3088): add LLM models and project models configuration (#1774)
Co-authored-by: paulruelle <[email protected]>
1 parent 6cf2aba commit 54ab49c

File tree

17 files changed

+977
-80
lines changed

17 files changed

+977
-80
lines changed

src/kili/adapters/kili_api_gateway/kili_api_gateway.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from kili.adapters.kili_api_gateway.cloud_storage import CloudStorageOperationMixin
77
from kili.adapters.kili_api_gateway.issue import IssueOperationMixin
88
from kili.adapters.kili_api_gateway.label.operations_mixin import LabelOperationMixin
9-
from kili.adapters.kili_api_gateway.model_configuration.operations_mixin import (
9+
from kili.adapters.kili_api_gateway.llm.operations_mixin import (
1010
ModelConfigurationOperationMixin,
1111
)
1212
from kili.adapters.kili_api_gateway.notification.operations_mixin import (
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""GraphQL payload data mappers for api keys operations."""
2+
3+
from typing import Dict
4+
5+
from kili.domain.llm import (
6+
AzureOpenAICredentials,
7+
ModelToCreateInput,
8+
ModelToUpdateInput,
9+
ModelType,
10+
OpenAISDKCredentials,
11+
OrganizationModelFilters,
12+
ProjectModelFilters,
13+
ProjectModelToCreateInput,
14+
ProjectModelToUpdateInput,
15+
)
16+
17+
18+
def model_where_wrapper(filter: OrganizationModelFilters) -> Dict:
19+
"""Build the GraphQL ProjectMapperWhere variable to be sent in an operation."""
20+
return {
21+
"organizationId": filter.organization_id,
22+
}
23+
24+
25+
def project_model_where_mapper(filter: ProjectModelFilters) -> Dict:
26+
"""Build the GraphQL ProjectMapperWhere variable to be sent in an operation."""
27+
return {
28+
"projectId": filter.project_id,
29+
"modelId": filter.model_id,
30+
}
31+
32+
33+
def map_create_model_input(data: ModelToCreateInput) -> Dict:
34+
"""Build the GraphQL ModelInput variable to be sent in an operation."""
35+
if data.type == ModelType.AZURE_OPEN_AI and isinstance(
36+
data.credentials, AzureOpenAICredentials
37+
):
38+
credentials = {
39+
"apiKey": data.credentials.api_key,
40+
"deploymentId": data.credentials.deployment_id,
41+
"endpoint": data.credentials.endpoint,
42+
}
43+
elif data.type == ModelType.OPEN_AI_SDK and isinstance(data.credentials, OpenAISDKCredentials):
44+
credentials = {"apiKey": data.credentials.api_key, "endpoint": data.credentials.endpoint}
45+
else:
46+
raise ValueError(
47+
f"Unsupported model type or credentials: {data.type}, {type(data.credentials)}"
48+
)
49+
50+
return {
51+
"credentials": credentials,
52+
"name": data.name,
53+
"type": data.type.value,
54+
"organizationId": data.organization_id,
55+
}
56+
57+
58+
def map_update_model_input(data: ModelToUpdateInput) -> Dict:
59+
"""Build the GraphQL UpdateModelInput variable to be sent in an operation."""
60+
input_dict = {}
61+
if data.name is not None:
62+
input_dict["name"] = data.name
63+
64+
if data.credentials is not None:
65+
if isinstance(data.credentials, AzureOpenAICredentials):
66+
credentials = {
67+
"apiKey": data.credentials.api_key,
68+
"deploymentId": data.credentials.deployment_id,
69+
"endpoint": data.credentials.endpoint,
70+
}
71+
elif isinstance(data.credentials, OpenAISDKCredentials):
72+
credentials = {
73+
"apiKey": data.credentials.api_key,
74+
"endpoint": data.credentials.endpoint,
75+
}
76+
else:
77+
raise ValueError(f"Unsupported credentials type: {type(data.credentials)}")
78+
input_dict["credentials"] = credentials
79+
80+
return input_dict
81+
82+
83+
def map_create_project_model_input(data: ProjectModelToCreateInput) -> Dict:
84+
"""Build the GraphQL ModelInput variable to be sent in an operation."""
85+
return {
86+
"projectId": data.project_id,
87+
"modelId": data.model_id,
88+
"configuration": data.configuration,
89+
}
90+
91+
92+
def map_update_project_model_input(data: ProjectModelToUpdateInput) -> Dict:
93+
"""Build the GraphQL UpdateProjectModelInput variable to be sent in an operation."""
94+
input_dict = {}
95+
if data.configuration is not None:
96+
input_dict["configuration"] = data.configuration
97+
return input_dict
98+
99+
100+
def map_delete_model_input(model_id: str) -> Dict:
101+
"""Map the input for the GraphQL deleteModel mutation."""
102+
return {
103+
"deleteModelId": model_id,
104+
}
105+
106+
107+
def map_delete_project_model_input(project_model_id: str) -> Dict:
108+
"""Map the input for the GraphQL deleteProjectModel mutation."""
109+
return {
110+
"deleteProjectModelId": project_model_id,
111+
}
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
"""GraphQL Asset operations."""
2+
3+
4+
def get_models_query(fragment: str) -> str:
5+
"""Return the GraphQL projectModels query."""
6+
return f"""
7+
query Models($where: ModelWhere!, $first: PageSize!, $skip: Int!) {{
8+
data: models(where: $where, first: $first, skip: $skip) {{
9+
{fragment}
10+
}}
11+
}}
12+
"""
13+
14+
15+
def get_model_query(fragment: str) -> str:
16+
"""Return the GraphQL model query by ID."""
17+
return f"""
18+
query Model($modelId: ID!) {{
19+
model(id: $modelId) {{
20+
{fragment}
21+
}}
22+
}}
23+
"""
24+
25+
26+
def get_create_model_mutation(fragment: str) -> str:
27+
"""Return the GraphQL createProjectModel mutation."""
28+
return f"""
29+
mutation CreateModel($input: CreateModelInput!) {{
30+
createModel(input: $input) {{
31+
{fragment}
32+
}}
33+
}}
34+
"""
35+
36+
37+
def get_update_model_mutation(fragment: str) -> str:
38+
"""Return the GraphQL updateModel mutation."""
39+
return f"""
40+
mutation UpdateModel($id: ID!, $input: UpdateModelInput!) {{
41+
updateModel(id: $id, input: $input) {{
42+
{fragment}
43+
}}
44+
}}
45+
"""
46+
47+
48+
def get_delete_model_mutation() -> str:
49+
"""Return the GraphQL deleteOrganizationModel mutation."""
50+
return """
51+
mutation DeleteModel($deleteModelId: ID!) {
52+
deleteModel(id: $deleteModelId)
53+
}
54+
"""
55+
56+
57+
def get_create_project_model_mutation(fragment: str) -> str:
58+
"""Return the GraphQL createProjectModel mutation."""
59+
return f"""
60+
mutation CreateProjectModel($input: CreateProjectModelInput!) {{
61+
createProjectModel(input: $input) {{
62+
{fragment}
63+
}}
64+
}}
65+
"""
66+
67+
68+
def get_update_project_model_mutation(fragment: str) -> str:
69+
"""Return the GraphQL updateProjectModel mutation."""
70+
return f"""
71+
mutation UpdateProjectModel($updateProjectModelId: ID!, $input: UpdateProjectModelInput!) {{
72+
updateProjectModel(id: $updateProjectModelId, input: $input) {{
73+
{fragment}
74+
}}
75+
}}
76+
"""
77+
78+
79+
def get_delete_project_model_mutation() -> str:
80+
"""Return the GraphQL deleteProjectModel mutation."""
81+
return """
82+
mutation DeleteProjectModel($deleteProjectModelId: ID!) {
83+
deleteProjectModel(id: $deleteProjectModelId)
84+
}
85+
"""
86+
87+
88+
def get_project_models_query(fragment: str) -> str:
89+
"""Return the GraphQL projectModels query."""
90+
return f"""
91+
query ProjectModels($where: ProjectModelWhere!, $first: PageSize!, $skip: Int!) {{
92+
data: projectModels(where: $where, first: $first, skip: $skip) {{
93+
{fragment}
94+
}}
95+
}}
96+
"""
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
"""Mixin extending Kili API Gateway class with Api Keys related operations."""
2+
3+
from typing import Dict, Generator, Optional
4+
5+
from kili.adapters.kili_api_gateway.base import BaseOperationMixin
6+
from kili.adapters.kili_api_gateway.helpers.queries import (
7+
PaginatedGraphQLQuery,
8+
QueryOptions,
9+
fragment_builder,
10+
)
11+
from kili.adapters.kili_api_gateway.llm.mappers import (
12+
map_create_model_input,
13+
map_create_project_model_input,
14+
map_delete_model_input,
15+
map_delete_project_model_input,
16+
map_update_model_input,
17+
map_update_project_model_input,
18+
model_where_wrapper,
19+
project_model_where_mapper,
20+
)
21+
from kili.adapters.kili_api_gateway.llm.operations import (
22+
get_create_model_mutation,
23+
get_create_project_model_mutation,
24+
get_delete_model_mutation,
25+
get_delete_project_model_mutation,
26+
get_model_query,
27+
get_models_query,
28+
get_project_models_query,
29+
get_update_model_mutation,
30+
get_update_project_model_mutation,
31+
)
32+
from kili.domain.llm import (
33+
ModelToCreateInput,
34+
ModelToUpdateInput,
35+
OrganizationModelFilters,
36+
ProjectModelFilters,
37+
ProjectModelToCreateInput,
38+
ProjectModelToUpdateInput,
39+
)
40+
from kili.domain.types import ListOrTuple
41+
42+
43+
class ModelConfigurationOperationMixin(BaseOperationMixin):
44+
"""Mixin extending Kili API Gateway class with model configuration related operations."""
45+
46+
def list_models(
47+
self,
48+
filters: OrganizationModelFilters,
49+
fields: ListOrTuple[str],
50+
options: Optional[QueryOptions] = None,
51+
) -> Generator[Dict, None, None]:
52+
"""List models with given options."""
53+
fragment = fragment_builder(fields)
54+
query = get_models_query(fragment)
55+
where = model_where_wrapper(filters)
56+
return PaginatedGraphQLQuery(self.graphql_client).execute_query_from_paginated_call(
57+
query,
58+
where,
59+
options if options else QueryOptions(disable_tqdm=False),
60+
"Retrieving organization models",
61+
None,
62+
)
63+
64+
def get_model(self, model_id: str, fields: ListOrTuple[str]) -> Dict:
65+
"""Get a model by ID."""
66+
fragment = fragment_builder(fields)
67+
query = get_model_query(fragment)
68+
variables = {"modelId": model_id}
69+
result = self.graphql_client.execute(query, variables)
70+
return result["model"]
71+
72+
def create_model(self, model: ModelToCreateInput) -> Dict:
73+
"""Send a GraphQL request calling createModel resolver."""
74+
payload = {"input": map_create_model_input(model)}
75+
fragment = fragment_builder(["id"])
76+
mutation = get_create_model_mutation(fragment)
77+
result = self.graphql_client.execute(mutation, payload)
78+
return result["createModel"]
79+
80+
def update_properties_in_model(self, model_id: str, model: ModelToUpdateInput) -> Dict:
81+
"""Send a GraphQL request calling updateModel resolver."""
82+
payload = {"id": model_id, "input": map_update_model_input(model)}
83+
fragment = fragment_builder(["id"])
84+
mutation = get_update_model_mutation(fragment)
85+
result = self.graphql_client.execute(mutation, payload)
86+
return result["updateModel"]
87+
88+
def delete_model(self, model_id: str) -> Dict:
89+
"""Send a GraphQL request to delete an organization model."""
90+
payload = map_delete_model_input(model_id)
91+
mutation = get_delete_model_mutation()
92+
result = self.graphql_client.execute(mutation, payload)
93+
return result["deleteModel"]
94+
95+
def create_project_model(self, project_model: ProjectModelToCreateInput) -> Dict:
96+
"""Send a GraphQL request calling createModel resolver."""
97+
payload = {"input": map_create_project_model_input(project_model)}
98+
fragment = fragment_builder(["id"])
99+
mutation = get_create_project_model_mutation(fragment)
100+
result = self.graphql_client.execute(mutation, payload)
101+
return result["createProjectModel"]
102+
103+
def update_project_model(
104+
self, project_model_id: str, project_model: ProjectModelToUpdateInput
105+
) -> Dict:
106+
"""Send a GraphQL request calling updateProjectModel resolver."""
107+
payload = {
108+
"updateProjectModelId": project_model_id,
109+
"input": map_update_project_model_input(project_model),
110+
}
111+
fragment = fragment_builder(["id", "configuration"])
112+
mutation = get_update_project_model_mutation(fragment)
113+
result = self.graphql_client.execute(mutation, payload)
114+
return result["updateProjectModel"]
115+
116+
def delete_project_model(self, project_model_id: str) -> Dict:
117+
"""Send a GraphQL request to delete a project model."""
118+
payload = map_delete_project_model_input(project_model_id)
119+
mutation = get_delete_project_model_mutation()
120+
result = self.graphql_client.execute(mutation, payload)
121+
return result["deleteProjectModel"]
122+
123+
def list_project_models(
124+
self,
125+
filters: ProjectModelFilters,
126+
fields: ListOrTuple[str],
127+
options: Optional[QueryOptions] = None,
128+
) -> Generator[Dict, None, None]:
129+
"""List project models with given options."""
130+
fragment = fragment_builder(fields)
131+
query = get_project_models_query(fragment)
132+
where = project_model_where_mapper(filters)
133+
return PaginatedGraphQLQuery(self.graphql_client).execute_query_from_paginated_call(
134+
query,
135+
where,
136+
options if options else QueryOptions(disable_tqdm=False),
137+
"Retrieving project models",
138+
None,
139+
)

src/kili/adapters/kili_api_gateway/model_configuration/mappers.py

Lines changed: 0 additions & 11 deletions
This file was deleted.

src/kili/adapters/kili_api_gateway/model_configuration/operations.py

Lines changed: 0 additions & 12 deletions
This file was deleted.

0 commit comments

Comments
 (0)