- from fastapi import FastAPI
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+ '''
+ ##################### TinyLlama + FastAPI + Docker #########################################
+ Author: Santiago Gonzalez Acevedo
+ Twitter: @locoalien
+ Python 3.11+
+ '''
+ # https://medium.com/@santiagosk80/tinyllama-fastapi-docker-microservicios-llm-ff99eb999f04
+ import logging
+ import os
+ import torch
+ from fastapi import FastAPI, HTTPException
+ from transformers import pipeline
+ import docs  # Local module with the API description shown in Swagger
+ from starlette.middleware.cors import CORSMiddleware  # CORS-level security
+ import json
- app = FastAPI()
+ logger = logging.getLogger(__name__)
+ # Create a FastAPI instance
+ app = FastAPI(title='LLM Chat Service', description=docs.desc, version=docs.version)
+ # CORS configuration (in case you want to deploy)
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["GET", "POST", "OPTIONS"],
+     allow_headers=["*"],
+ )
+ logger.info('Adding v1 endpoints..')
- torch.random.manual_seed(0)
- model = AutoModelForCausalLM.from_pretrained(
-     "microsoft/Phi-3-mini-4k-instruct",
-     device_map="cuda",
-     torch_dtype="auto",
-     trust_remote_code=True,
- )
+ # Load the model and tokenizer
+ pipe = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.bfloat16, device_map="auto")
- tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
-
- messages = [
-     {"role": "system", "content": "You are a helpful AI assistant."},
-     {"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"},
-     {"role": "assistant", "content": "Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2. Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey."},
-     {"role": "user", "content": "What about solving an 2x + 3 = 7 equation?"},
- ]
-
- pipe = pipeline(
-     "text-generation",
-     model=model,
-     tokenizer=tokenizer,
- )
-
- generation_args = {
-     "max_new_tokens": 500,
-     "return_full_text": False,
-     "temperature": 0.0,
-     "do_sample": False,
- }
-
- @app.post("/predict")
- async def predict(text: str):
-     output = pipe(messages, **generation_args)
-     return {"prediction": output[0]['generated_text']}
+ # "/chat" endpoint: receives a text, passes it through the model, and returns the response
+ @app.post("/chat")
+ async def chat(text: str):
+     try:
+         # Model behavior configuration
+         messages = [
+             {
+                 "role": "system",
+                 "content": "Solo quiero la respuesta a la pregunta sin repetir la pregunta, por favor.",
+             },
+             {"role": "user", "content": f"{text}"},
+         ]
+         # Build the prompt for the model
+         prompt = pipe.tokenizer.apply_chat_template(
+             messages, tokenize=False, add_generation_prompt=True
+         )
+         # Generation (sampling) settings
+         outputs = pipe(
+             prompt,
+             max_new_tokens=256,
+             do_sample=True,
+             temperature=0.3,
+             top_k=50,
+             top_p=0.95,
+         )
+         # Model output
+         output = outputs[0]["generated_text"]
+         # Extract the answer portion after the "<|assistant|>" tag
+         assistant_response = output.split("<|assistant|>")[-1].strip()
+         json_results = json.dumps({"response": assistant_response}, ensure_ascii=False, indent=4).encode('utf8')
+         return json.loads(json_results)
+     except Exception as e:
+         logger.error(f'Error: {e}')
+         raise HTTPException(status_code=500, detail='Internal Server Error')
+
+ # Run the server with uvicorn
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app)
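
The service imports a local `docs` module and reads `docs.desc` and `docs.version` for the Swagger metadata, but that file is not part of this diff. A minimal sketch of what it could contain, assuming only the two attributes the code above actually reads (the wording and version number are placeholders):

# docs.py -- hypothetical companion module; the app only reads `desc` and `version`
desc = "Microservice that exposes the TinyLlama-1.1B-Chat model through a /chat endpoint."
version = "1.0.0"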
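Because `chat(text: str)` declares a plain string parameter, FastAPI treats `text` as a query parameter, so clients call POST /chat?text=.... A quick way to exercise the endpoint without starting uvicorn is FastAPI's TestClient; this sketch assumes the file above is saved as main.py (the module name is an assumption):

# Hypothetical smoke test for the /chat endpoint; assumes the code above lives in main.py.
from fastapi.testclient import TestClient
from main import app

client = TestClient(app)

# `text` is sent as a query parameter because the endpoint declares `text: str`.
response = client.post("/chat", params={"text": "What is the capital of France?"})
print(response.status_code)  # 200 on success
print(response.json())       # {"response": "..."} generated by TinyLlama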