Skip to content

Commit 620bfa3

Browse files
authored
Merge pull request #38 from code4me-me/aral_user_study
Aral's User Study
2 parents e723b21 + 1b0af53 commit 620bfa3

File tree

15 files changed

+2907
-596
lines changed

15 files changed

+2907
-596
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,6 @@ build
55
venv
66
__pycache__
77
users*.json
8+
data_aral
89

10+
models

code4me-server/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ nltk~=3.8.1
88
datasets~=2.9.0
99
markdown~=3.4.1
1010
joblib~=1.2.0
11+
safetensors

code4me-server/src/api.py

Lines changed: 128 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,136 @@
1-
import glob
2-
import json
3-
import os
4-
import uuid
5-
import random
6-
from typing import List
7-
import time
1+
from __future__ import annotations
2+
import os, time, random, json, uuid, glob, torch, traceback
3+
4+
from enum import Enum
5+
from typing import List, Tuple
86
from model import Model
97
from datetime import datetime
108
from joblib import Parallel, delayed
11-
from flask import Blueprint, request, Response, redirect
12-
import torch
13-
9+
from flask import Blueprint, request, Response, redirect, current_app
1410
from limiter import limiter
1511

12+
from user_study import (
13+
filter_request,
14+
store_completion_request,
15+
should_prompt_survey,
16+
USER_STUDY_DIR,
17+
)
18+
1619
v1 = Blueprint("v1", __name__)
20+
v2 = Blueprint("v2", __name__)
1721

1822
os.makedirs("data", exist_ok=True)
1923

24+
def authorise(req) -> str:
    ''' Return the bearer token that authorises the request.

    Raises ValueError if the Authorization header is absent or carries
    no token, so callers can turn it into a 4xx response uniformly.
    '''
    # req.authorization is None when no Authorization header was sent at all;
    # dereferencing .token on it would raise AttributeError instead of the
    # documented ValueError, so guard it explicitly.
    authorization = req.authorization
    if authorization is None or authorization.token is None:
        raise ValueError("Missing bearer token")
    return authorization.token
31+
32+
def get_predictions(completion_request: dict) -> Tuple[float, dict[str, str]]:
    ''' Run every model on the request in parallel.

    Returns a tuple of (elapsed milliseconds, {model_name: top prediction}).
    NOTE(review): assumes completion_request always has 'prefix' and 'suffix'
    keys — confirm against the client payload.
    '''

    prefix = completion_request['prefix'].rstrip()
    suffix = completion_request['suffix']

    def predict_model(model: Model) -> str:
        try:
            # model.value[1] is the model's predict callable; keep only
            # the first (top-ranked) suggestion.
            return model.value[1](prefix, suffix)[0]
        except torch.cuda.OutOfMemoryError:
            # Unrecoverable for this process: bail out so the supervisor
            # restarts the server with a clean CUDA context.
            exit(1)

    t0 = datetime.now()
    predictions = Parallel(n_jobs=os.cpu_count(), prefer="threads")(delayed(predict_model)(model) for model in Model)
    # Renamed from 'time' — that name shadowed the imported time module.
    elapsed_ms = (datetime.now() - t0).total_seconds() * 1000

    predictions = {model.name: prediction for model, prediction in zip(Model, predictions)}
    return elapsed_ms, predictions
50+
51+
@v2.route("/prediction/autocomplete", methods=["POST"])
@limiter.limit("4000/hour")
def autocomplete_v2():
    ''' Authorise the caller, optionally run the models, log and persist the
    completion request, and return predictions plus a verify token.

    On any failure, logs a fresh error uuid and returns it with status 400.
    '''

    # Pre-bind so the except-handler below can always reference it: if
    # authorise() raises, an unbound user_uuid would turn the error path
    # itself into a NameError.
    user_uuid = None
    try:
        # TODO: As we want every request to be authorised, this can be extracted into a decorator
        user_uuid = authorise(request)
        request_json = request.json

        # TODO: add a None filter type for baseline comparison
        filter_time, filter_type, should_filter = filter_request(user_uuid, request_json)

        # Manual triggers always get predictions, even when the filter fired.
        predict_time, predictions = get_predictions(request_json) \
            if (not should_filter) or (request_json['trigger'] == 'manual') \
            else (None, {})

        log_filter = f'\033[1m{"filter" if should_filter else "predict"}\033[0m'
        log_context = f'{request_json["prefix"][-10:]}{request_json["suffix"][:5]}'
        current_app.logger.warning(f'{log_filter} {log_context} \t{filter_type} {[v[:10] for v in predictions.values()]}')

        # Filtered requests get no verify token and never prompt the survey.
        verify_token = uuid.uuid4().hex if not should_filter else ''
        prompt_survey = should_prompt_survey(user_uuid) if not should_filter else False

        store_completion_request(user_uuid, verify_token, {
            **request_json,
            'timestamp': datetime.now().isoformat(),
            'filter_type': filter_type,
            'filter_time': filter_time,
            'should_filter': should_filter,
            'predict_time': predict_time,
            'predictions': predictions,
            'survey': prompt_survey,
            'study_version': '0.0.1'
        })

        return {
            'predictions': predictions,
            'verifyToken': verify_token,
            'survey': prompt_survey
        }

    except Exception:

        # Hand the client an opaque error id it can report, and keep the
        # full traceback server-side.
        error_uuid = uuid.uuid4().hex
        current_app.logger.warning(f'''
            Error {error_uuid} for {user_uuid if user_uuid is not None else "unauthenticated user"}
            {request.json if request.is_json else "no request json found"}
        ''')
        traceback.print_exc()

        return response({ "error": error_uuid }, status=400)
102+
103+
@v2.route("/prediction/verify", methods=["POST"])
@limiter.limit("4000/hour")
def verify_v2():
    ''' Merge the client's ground-truth payload into the stored completion
    request identified by its verify token.

    A token may be redeemed only once; reuse, unknown, or malformed tokens
    yield a 400 response.
    '''

    user_uuid = authorise(request)
    verify_json = request.json

    # current_app.logger.info(verify_json)

    verify_token = verify_json['verifyToken']

    # Tokens are uuid4().hex strings (lowercase hex). Reject anything else
    # so a crafted token such as '../..' cannot traverse outside the user's
    # directory when joined into the path below.
    # NOTE(review): user_uuid is also path-joined; presumably authorise()
    # guarantees a safe value — confirm and validate there if not.
    if not verify_token or not all(c in '0123456789abcdef' for c in verify_token):
        return response({"error": "Invalid verify token"}, status=400)

    file_path = os.path.join(USER_STUDY_DIR, user_uuid, f'{verify_token}.json')

    try:
        completion_file = open(file_path, 'r+')
    except FileNotFoundError:
        # Unknown token: answer 400 rather than crashing with a 500.
        return response({"error": "Invalid verify token"}, status=400)

    with completion_file:
        completion_json = json.load(completion_file)

        # The presence of ground_truth marks a token as already redeemed.
        if 'ground_truth' in completion_json:
            return response({
                "error": "Already used verify token"
            }, status=400)

        completion_json.update(verify_json)

        # Rewrite the file in place; truncate() drops any stale tail when
        # the new JSON is shorter than the old content.
        completion_file.seek(0)
        completion_file.write(json.dumps(completion_json))
        completion_file.truncate()

    return response({'success': True})
130+
131+
132+
##### NOTE: OLD IMPLEMENTATION KEPT FOR JETBRAINS USERS ####
133+
# (and, those that have turned off auto-update for VSC extensions)
20134

21135
@v1.route("/prediction/autocomplete", methods=["POST"])
22136
@limiter.limit("1000/hour")
@@ -85,8 +199,10 @@ def predict_model(model: Model) -> List[str]:
85199
"rightContext": right_context if store_context else None
86200
}))
87201

88-
n_suggestions = len(glob.glob(f"data/{user_token}*.json"))
89-
survey = n_suggestions >= 100 and n_suggestions % 50 == 0
202+
# # # TODO: disabled surveys temporarily, as we are currently looking through >1M files on every request.
203+
# n_suggestions = len(glob.glob(f"data/{user_token}*.json"))
204+
# survey = n_suggestions >= 100 and n_suggestions % 50 == 0
205+
survey = False
90206

91207
return response({
92208
"predictions": unique_predictions,

code4me-server/src/app.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
1-
from pathlib import Path
1+
import markdown, os
22

3-
import markdown
3+
from pathlib import Path
44
from flask import Flask, jsonify, render_template
5-
from api import v1
5+
from api import v1, v2
66
from limiter import limiter
77

88
app = Flask(__name__, static_folder="static", template_folder="templates")
99
limiter.init_app(app)
1010
app.register_blueprint(v1, url_prefix='/api/v1')
11+
app.register_blueprint(v2, url_prefix='/api/v2')
1112

12-
index_md = markdown.markdown(Path("markdowns/index.md").read_text())
13+
markdown_path = 'markdowns/index.md'
14+
index_md = markdown.markdown(Path(markdown_path).read_text())
1315

1416

1517
@app.errorhandler(429)

code4me-server/src/codegpt.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33
import os
44
import torch
55

6+
# env variable for local testing
7+
CODE4ME_TEST = os.environ.get("CODE4ME_TEST", "false") == "true"
8+
69
checkpoint_path = "gpt2" # default checkpoint is the non-finetuned gpt2 model
710

811
# if CODEGPT_CHECKPOINT_PATH is set, use that checkpoint
912
if os.environ.get("CODEGPT_CHECKPOINT_PATH"):
1013
checkpoint_path = os.environ.get("CODEGPT_CHECKPOINT_PATH")
1114

12-
if not os.path.exists(checkpoint_path):
15+
if not os.path.exists(checkpoint_path) and not CODE4ME_TEST:
1316
raise ValueError(f"Invalid checkpoint path: '{checkpoint_path}'")
1417

1518
config = GPT2Config
@@ -23,7 +26,7 @@
2326
class Beam(object):
2427
def __init__(self, size, sos, eos):
2528
self.size = size
26-
self.tt = torch.cuda
29+
self.tt = torch.cuda if torch.cuda.is_available() else torch
2730
# The score for each translation on the beam.
2831
self.scores = self.tt.FloatTensor(size).zero_().to(device)
2932
# The backpointers at each time-step.
@@ -159,7 +162,8 @@ def DecodeIds(idxs):
159162
break_ids = [tokenizer.sep_token_id]
160163

161164
m = torch.nn.LogSoftmax(dim=-1).to(device)
162-
zero = torch.cuda.LongTensor(1).fill_(0).to(device)
165+
# I presume the .cuda. is not necessary here if it is moved to the CUDA device immediately, but not risking it.
166+
zero = torch.cuda.LongTensor(1).fill_(0).to(device) if not CODE4ME_TEST else torch.LongTensor(1).fill_(0).to(device)
163167

164168
def codegpt_predict(left_context: str, right_context: str) -> List[str]:
165169
left_context = left_context.replace("\n", "<EOL>")
@@ -179,7 +183,7 @@ def codegpt_predict(left_context: str, right_context: str) -> List[str]:
179183
inputs = torch.tensor(tokens, device=device).unsqueeze(0)
180184
with torch.no_grad():
181185
beam_size = 1
182-
outputs = model(inputs[:, :-1])[1]
186+
outputs = model(inputs)[1]
183187
p = []
184188
for i in range(inputs.shape[0]):
185189
past = [torch.cat([x[0].unsqueeze(0), x[1].unsqueeze(0)], dim=0) if type(x) == tuple else x for x in

code4me-server/src/model.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,25 @@
1+
import os
12
from enum import Enum
3+
from typing import Callable
4+
5+
# NOTE: Convenient for testing, use preset generate functions
6+
# if os.getenv("CODE4ME_TEST", "false") == "true":
7+
# print('''
8+
# \033[1m WARNING: RUNNING IN TEST MODE \033[0m
9+
# ''')
10+
# # if the env variable TEST_MODE is set to True, then remap model.generate to lambda: 'model_name'
11+
12+
# incoder = type("InCoder", (object,), {})
13+
# unixcoder_wrapper = type("UniXCoder", (object,), {})
14+
# import codegpt
15+
# # codegpt = type("CodeGPT", (object,), {})
16+
17+
# incoder.generate = lambda left, right: ['predict_incoder']
18+
# unixcoder_wrapper.generate = lambda left, right: [' predict_unixcoder']
19+
20+
# # codegpt.codegpt_predict = lambda left, right: [' (predict_codegpt']
21+
# else:
22+
# # ooh yeah, import statements in an else stmt; i see new things every day
223
import incoder
324
import unixcoder_wrapper
425
import codegpt

0 commit comments

Comments
 (0)