|
| 1 | +"""Mixin extending Kili API Gateway class with Api Keys related operations.""" |
| 2 | + |
| 3 | +from typing import Dict, Generator, Optional |
| 4 | + |
| 5 | +from kili.adapters.kili_api_gateway.base import BaseOperationMixin |
| 6 | +from kili.adapters.kili_api_gateway.helpers.queries import ( |
| 7 | + PaginatedGraphQLQuery, |
| 8 | + QueryOptions, |
| 9 | + fragment_builder, |
| 10 | +) |
| 11 | +from kili.adapters.kili_api_gateway.llm.mappers import ( |
| 12 | + map_create_model_input, |
| 13 | + map_create_project_model_input, |
| 14 | + map_delete_model_input, |
| 15 | + map_delete_project_model_input, |
| 16 | + map_update_model_input, |
| 17 | + map_update_project_model_input, |
| 18 | + model_where_wrapper, |
| 19 | + project_model_where_mapper, |
| 20 | +) |
| 21 | +from kili.adapters.kili_api_gateway.llm.operations import ( |
| 22 | + get_create_model_mutation, |
| 23 | + get_create_project_model_mutation, |
| 24 | + get_delete_model_mutation, |
| 25 | + get_delete_project_model_mutation, |
| 26 | + get_model_query, |
| 27 | + get_models_query, |
| 28 | + get_project_models_query, |
| 29 | + get_update_model_mutation, |
| 30 | + get_update_project_model_mutation, |
| 31 | +) |
| 32 | +from kili.domain.llm import ( |
| 33 | + ModelToCreateInput, |
| 34 | + ModelToUpdateInput, |
| 35 | + OrganizationModelFilters, |
| 36 | + ProjectModelFilters, |
| 37 | + ProjectModelToCreateInput, |
| 38 | + ProjectModelToUpdateInput, |
| 39 | +) |
| 40 | +from kili.domain.types import ListOrTuple |
| 41 | + |
| 42 | + |
| 43 | +class ModelConfigurationOperationMixin(BaseOperationMixin): |
| 44 | + """Mixin extending Kili API Gateway class with model configuration related operations.""" |
| 45 | + |
| 46 | + def list_models( |
| 47 | + self, |
| 48 | + filters: OrganizationModelFilters, |
| 49 | + fields: ListOrTuple[str], |
| 50 | + options: Optional[QueryOptions] = None, |
| 51 | + ) -> Generator[Dict, None, None]: |
| 52 | + """List models with given options.""" |
| 53 | + fragment = fragment_builder(fields) |
| 54 | + query = get_models_query(fragment) |
| 55 | + where = model_where_wrapper(filters) |
| 56 | + return PaginatedGraphQLQuery(self.graphql_client).execute_query_from_paginated_call( |
| 57 | + query, |
| 58 | + where, |
| 59 | + options if options else QueryOptions(disable_tqdm=False), |
| 60 | + "Retrieving organization models", |
| 61 | + None, |
| 62 | + ) |
| 63 | + |
| 64 | + def get_model(self, model_id: str, fields: ListOrTuple[str]) -> Dict: |
| 65 | + """Get a model by ID.""" |
| 66 | + fragment = fragment_builder(fields) |
| 67 | + query = get_model_query(fragment) |
| 68 | + variables = {"modelId": model_id} |
| 69 | + result = self.graphql_client.execute(query, variables) |
| 70 | + return result["model"] |
| 71 | + |
| 72 | + def create_model(self, model: ModelToCreateInput) -> Dict: |
| 73 | + """Send a GraphQL request calling createModel resolver.""" |
| 74 | + payload = {"input": map_create_model_input(model)} |
| 75 | + fragment = fragment_builder(["id"]) |
| 76 | + mutation = get_create_model_mutation(fragment) |
| 77 | + result = self.graphql_client.execute(mutation, payload) |
| 78 | + return result["createModel"] |
| 79 | + |
| 80 | + def update_properties_in_model(self, model_id: str, model: ModelToUpdateInput) -> Dict: |
| 81 | + """Send a GraphQL request calling updateModel resolver.""" |
| 82 | + payload = {"id": model_id, "input": map_update_model_input(model)} |
| 83 | + fragment = fragment_builder(["id"]) |
| 84 | + mutation = get_update_model_mutation(fragment) |
| 85 | + result = self.graphql_client.execute(mutation, payload) |
| 86 | + return result["updateModel"] |
| 87 | + |
| 88 | + def delete_model(self, model_id: str) -> Dict: |
| 89 | + """Send a GraphQL request to delete an organization model.""" |
| 90 | + payload = map_delete_model_input(model_id) |
| 91 | + mutation = get_delete_model_mutation() |
| 92 | + result = self.graphql_client.execute(mutation, payload) |
| 93 | + return result["deleteModel"] |
| 94 | + |
| 95 | + def create_project_model(self, project_model: ProjectModelToCreateInput) -> Dict: |
| 96 | + """Send a GraphQL request calling createModel resolver.""" |
| 97 | + payload = {"input": map_create_project_model_input(project_model)} |
| 98 | + fragment = fragment_builder(["id"]) |
| 99 | + mutation = get_create_project_model_mutation(fragment) |
| 100 | + result = self.graphql_client.execute(mutation, payload) |
| 101 | + return result["createProjectModel"] |
| 102 | + |
| 103 | + def update_project_model( |
| 104 | + self, project_model_id: str, project_model: ProjectModelToUpdateInput |
| 105 | + ) -> Dict: |
| 106 | + """Send a GraphQL request calling updateProjectModel resolver.""" |
| 107 | + payload = { |
| 108 | + "updateProjectModelId": project_model_id, |
| 109 | + "input": map_update_project_model_input(project_model), |
| 110 | + } |
| 111 | + fragment = fragment_builder(["id", "configuration"]) |
| 112 | + mutation = get_update_project_model_mutation(fragment) |
| 113 | + result = self.graphql_client.execute(mutation, payload) |
| 114 | + return result["updateProjectModel"] |
| 115 | + |
| 116 | + def delete_project_model(self, project_model_id: str) -> Dict: |
| 117 | + """Send a GraphQL request to delete a project model.""" |
| 118 | + payload = map_delete_project_model_input(project_model_id) |
| 119 | + mutation = get_delete_project_model_mutation() |
| 120 | + result = self.graphql_client.execute(mutation, payload) |
| 121 | + return result["deleteProjectModel"] |
| 122 | + |
| 123 | + def list_project_models( |
| 124 | + self, |
| 125 | + filters: ProjectModelFilters, |
| 126 | + fields: ListOrTuple[str], |
| 127 | + options: Optional[QueryOptions] = None, |
| 128 | + ) -> Generator[Dict, None, None]: |
| 129 | + """List project models with given options.""" |
| 130 | + fragment = fragment_builder(fields) |
| 131 | + query = get_project_models_query(fragment) |
| 132 | + where = project_model_where_mapper(filters) |
| 133 | + return PaginatedGraphQLQuery(self.graphql_client).execute_query_from_paginated_call( |
| 134 | + query, |
| 135 | + where, |
| 136 | + options if options else QueryOptions(disable_tqdm=False), |
| 137 | + "Retrieving project models", |
| 138 | + None, |
| 139 | + ) |
0 commit comments