|
1 |
| -import glob |
2 |
| -import json |
3 |
| -import os |
4 |
| -import uuid |
5 |
| -import random |
6 |
| -from typing import List |
7 |
| -import time |
| 1 | +from __future__ import annotations |
| 2 | +import os, time, random, json, uuid, glob, torch, traceback |
| 3 | + |
| 4 | +from enum import Enum |
| 5 | +from typing import List, Tuple |
8 | 6 | from model import Model
|
9 | 7 | from datetime import datetime
|
10 | 8 | from joblib import Parallel, delayed
|
11 |
| -from flask import Blueprint, request, Response, redirect |
12 |
| -import torch |
13 |
| - |
| 9 | +from flask import Blueprint, request, Response, redirect, current_app |
14 | 10 | from limiter import limiter
|
15 | 11 |
|
| 12 | +from user_study import ( |
| 13 | + filter_request, |
| 14 | + store_completion_request, |
| 15 | + should_prompt_survey, |
| 16 | + USER_STUDY_DIR, |
| 17 | +) |
| 18 | + |
16 | 19 | v1 = Blueprint("v1", __name__)
|
| 20 | +v2 = Blueprint("v2", __name__) |
17 | 21 |
|
18 | 22 | os.makedirs("data", exist_ok=True)
|
19 | 23 |
|
def authorise(req) -> str:
    ''' Authorise the request.

    Extracts the bearer token from the request's Authorization header.

    Raises:
        ValueError: if the Authorization header or its token is missing.

    Returns:
        The bearer token (used downstream as the user's UUID).
    '''
    # req.authorization is None when no Authorization header was sent at all;
    # checking it first avoids an AttributeError masking the intended ValueError.
    credentials = req.authorization
    if credentials is None or credentials.token is None:
        raise ValueError("Missing bearer token")
    return credentials.token
def get_predictions(completion_request: dict) -> Tuple[float, dict[str, str]]:
    ''' Run every model on the request's prefix/suffix in parallel.

    Args:
        completion_request: request JSON; must contain 'prefix' and 'suffix'.

    Returns:
        A (elapsed_ms, predictions) tuple, where predictions maps each
        model name to that model's top completion.
    '''
    prefix = completion_request['prefix'].rstrip()
    suffix = completion_request['suffix']

    def predict_model(model: Model) -> str:
        try:
            # model.value[1] is the model's predict callable; keep the top completion.
            return model.value[1](prefix, suffix)[0]
        except torch.cuda.OutOfMemoryError:
            # CUDA OOM leaves the device unusable; exit so the supervisor
            # can restart the process with a clean GPU.
            exit(1)

    t0 = datetime.now()
    raw_predictions = Parallel(n_jobs=os.cpu_count(), prefer="threads")(
        delayed(predict_model)(model) for model in Model
    )
    # NOTE: named elapsed_ms (was `time`) to avoid shadowing the imported time module.
    elapsed_ms = (datetime.now() - t0).total_seconds() * 1000

    predictions = {model.name: prediction for model, prediction in zip(Model, raw_predictions)}
    return elapsed_ms, predictions
| 51 | +@v2.route("/prediction/autocomplete", methods=["POST"]) |
| 52 | +@limiter.limit("4000/hour") |
| 53 | +def autocomplete_v2(): |
| 54 | + |
| 55 | + try: |
| 56 | + # TODO: As we want every request to be authorised, this can be extracted into a decorator |
| 57 | + user_uuid = authorise(request) |
| 58 | + request_json = request.json |
| 59 | + |
| 60 | + # TODO: add a None filter type for baseline comparison |
| 61 | + filter_time, filter_type, should_filter = filter_request(user_uuid, request_json) |
| 62 | + |
| 63 | + predict_time, predictions = get_predictions(request_json) \ |
| 64 | + if (not should_filter) or (request_json['trigger'] == 'manual') \ |
| 65 | + else (None, {}) |
| 66 | + |
| 67 | + log_filter = f'\033[1m{"filter" if should_filter else "predict"}\033[0m' |
| 68 | + log_context = f'{request_json["prefix"][-10:]}•{request_json["suffix"][:5]}' |
| 69 | + current_app.logger.warning(f'{log_filter} {log_context} \t{filter_type} {[v[:10] for v in predictions.values()]}') |
| 70 | + |
| 71 | + verify_token = uuid.uuid4().hex if not should_filter else '' |
| 72 | + prompt_survey = should_prompt_survey(user_uuid) if not should_filter else False |
| 73 | + |
| 74 | + store_completion_request(user_uuid, verify_token, { |
| 75 | + **request_json, |
| 76 | + 'timestamp': datetime.now().isoformat(), |
| 77 | + 'filter_type': filter_type, |
| 78 | + 'filter_time': filter_time, |
| 79 | + 'should_filter': should_filter, |
| 80 | + 'predict_time': predict_time, |
| 81 | + 'predictions': predictions, |
| 82 | + 'survey': prompt_survey, |
| 83 | + 'study_version': '0.0.1' |
| 84 | + }) |
| 85 | + |
| 86 | + return { |
| 87 | + 'predictions': predictions, |
| 88 | + 'verifyToken': verify_token, |
| 89 | + 'survey': prompt_survey |
| 90 | + } |
| 91 | + |
| 92 | + except Exception as e: |
| 93 | + |
| 94 | + error_uuid = uuid.uuid4().hex |
| 95 | + current_app.logger.warning(f''' |
| 96 | + Error {error_uuid} for {user_uuid if user_uuid is not None else "unauthenticated user"} |
| 97 | + {request.json if request.is_json else "no request json found"} |
| 98 | + ''') |
| 99 | + traceback.print_exc() |
| 100 | + |
| 101 | + return response({ "error": error_uuid }, status=400) |
| 102 | + |
| 103 | +@v2.route("/prediction/verify", methods=["POST"]) |
| 104 | +@limiter.limit("4000/hour") |
| 105 | +def verify_v2(): |
| 106 | + |
| 107 | + user_uuid = authorise(request) |
| 108 | + verify_json = request.json |
| 109 | + |
| 110 | + # current_app.logger.info(verify_json) |
| 111 | + |
| 112 | + verify_token = verify_json['verifyToken'] |
| 113 | + file_path = os.path.join(USER_STUDY_DIR, user_uuid, f'{verify_token}.json') |
| 114 | + |
| 115 | + with open(file_path, 'r+') as completion_file: |
| 116 | + completion_json = json.load(completion_file) |
| 117 | + |
| 118 | + if 'ground_truth' in completion_json: |
| 119 | + return response({ |
| 120 | + "error": "Already used verify token" |
| 121 | + }, status=400) |
| 122 | + |
| 123 | + completion_json.update(verify_json) |
| 124 | + |
| 125 | + completion_file.seek(0) |
| 126 | + completion_file.write(json.dumps(completion_json)) |
| 127 | + completion_file.truncate() |
| 128 | + |
| 129 | + return response({'success': True}) |
| 130 | + |
| 131 | + |
| 132 | +##### NOTE: OLD IMPLEMENTATION KEPT FOR JETBRAINS USERS #### |
| 133 | +# (and those that have turned off auto-update for vsc extensions) |
20 | 134 |
|
21 | 135 | @v1.route("/prediction/autocomplete", methods=["POST"])
|
22 | 136 | @limiter.limit("1000/hour")
|
@@ -85,8 +199,10 @@ def predict_model(model: Model) -> List[str]:
|
85 | 199 | "rightContext": right_context if store_context else None
|
86 | 200 | }))
|
87 | 201 |
|
88 |
| - n_suggestions = len(glob.glob(f"data/{user_token}*.json")) |
89 |
| - survey = n_suggestions >= 100 and n_suggestions % 50 == 0 |
| 202 | + # # # TODO: disabled surveys temporarily, as we are currently looking through >1M files on every request. |
| 203 | + # n_suggestions = len(glob.glob(f"data/{user_token}*.json")) |
| 204 | + # survey = n_suggestions >= 100 and n_suggestions % 50 == 0 |
| 205 | + survey = False |
90 | 206 |
|
91 | 207 | return response({
|
92 | 208 | "predictions": unique_predictions,
|
|
0 commit comments