Skip to content

Commit 2205279

Browse files
authored
Merge pull request #163 from VoltaML/experimental
Version 0.4.0
2 parents 4babe57 + ded547d commit 2205279

File tree

238 files changed

+16001
-10214
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

238 files changed

+16001
-10214
lines changed

.dockerignore

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ pyproject.toml
1010
# User generated files
1111
/converted
1212
/data
13+
!/data/themes/dark.json
14+
!/data/themes/dark_flat.json
15+
!/data/themes/light.json
16+
!/data/themes/light_flat.json
1317
/engine
1418
/onnx
1519
/traced_unet
@@ -22,9 +26,6 @@ yarn.lock
2226
# Frontend
2327
frontend/dist/
2428

25-
# Static files
26-
/static
27-
2829
# Python
2930
/venv
3031

@@ -42,6 +43,7 @@ test.docker-compose.yml
4243

4344
# Other
4445
**/**.pyc
46+
poetry.lock
4547
.pytest_cache
4648
.coverage
4749
/.ruff_cache

.github/workflows/ruff.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
name: Ruff
2+
on: [push, pull_request]
3+
jobs:
4+
ruff:
5+
runs-on: ubuntu-latest
6+
steps:
7+
- uses: actions/checkout@v3
8+
- uses: chartboost/ruff-action@v1

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ docs/.vitepress/dist
8080
external
8181
/tmp
8282
/data
83-
/data/settings.json
8483
/AITemplate
8584

8685
# Ignore for black

.vscode/settings.json

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
{
2-
"python.linting.pylintEnabled": true,
3-
"python.linting.enabled": true,
42
"python.testing.pytestArgs": ["."],
53
"python.testing.unittestEnabled": false,
64
"python.testing.pytestEnabled": true,
75
"python.analysis.typeCheckingMode": "basic",
8-
"python.formatting.provider": "black",
96
"python.languageServer": "Pylance",
107
"rust-analyzer.linkedProjects": ["./manager/Cargo.toml"]
118
}

api/app.py

Lines changed: 55 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,21 @@
55
from pathlib import Path
66

77
from api_analytics.fastapi import Analytics
8-
from fastapi import Depends, FastAPI, Request
8+
from fastapi import Depends, FastAPI, Request, status
99
from fastapi.exceptions import RequestValidationError
1010
from fastapi.middleware.cors import CORSMiddleware
11-
from fastapi.responses import FileResponse
11+
from fastapi.responses import FileResponse, JSONResponse
1212
from fastapi.staticfiles import StaticFiles
1313
from fastapi_simple_cachecontrol.middleware import CacheControlMiddleware
1414
from fastapi_simple_cachecontrol.types import CacheControl
1515
from huggingface_hub.hf_api import LocalTokenNotFoundError
16-
from starlette import status
17-
from starlette.responses import JSONResponse
1816

1917
from api import websocket_manager
20-
from api.routes import (
21-
general,
22-
generate,
23-
hardware,
24-
models,
25-
outputs,
26-
settings,
27-
static,
28-
test,
29-
ws,
30-
)
18+
from api.routes import static, ws
3119
from api.websockets.data import Data
3220
from api.websockets.notification import Notification
3321
from core import shared
22+
from core.types import InferenceBackend
3423

3524
logger = logging.getLogger(__name__)
3625

@@ -100,50 +89,53 @@ async def hf_token_error(_request, _exc):
10089

10190

10291
@app.exception_handler(404)
103-
async def custom_http_exception_handler(_request, _exc):
92+
async def custom_http_exception_handler(request: Request, _exc):
10493
"Redirect back to the main page (frontend will handle it)"
10594

95+
if request.url.path.startswith("/api"):
96+
return JSONResponse(
97+
content={
98+
"status_code": 10404,
99+
"message": "Not Found",
100+
"data": None,
101+
},
102+
status_code=status.HTTP_404_NOT_FOUND,
103+
)
104+
106105
return FileResponse("frontend/dist/index.html")
107106

108107

109108
@app.on_event("startup")
110109
async def startup_event():
111110
"Prepare the event loop for other asynchronous tasks"
112111

113-
# Inject the logger
114-
from rich.logging import RichHandler
115-
116-
# Disable duplicate logger
117-
logging.getLogger("uvicorn").handlers = []
118-
119-
for logger_ in ("uvicorn.access", "uvicorn.error", "fastapi"):
120-
l = logging.getLogger(logger_)
121-
handler = RichHandler(
122-
rich_tracebacks=True, show_time=False, omit_repeated_times=False
123-
)
124-
handler.setFormatter(
125-
logging.Formatter(
126-
fmt="%(asctime)s | %(name)s » %(message)s", datefmt="%H:%M:%S"
127-
)
128-
)
129-
l.handlers = [handler]
130-
131112
if logger.level > logging.DEBUG:
132113
from transformers import logging as transformers_logging
133114

134115
transformers_logging.set_verbosity_error()
135116

136117
shared.asyncio_loop = asyncio.get_event_loop()
118+
websocket_manager.loop = shared.asyncio_loop
137119

138-
sync_task = asyncio.create_task(websocket_manager.sync_loop())
139-
logger.info("Started WebSocketManager sync loop")
140120
perf_task = asyncio.create_task(websocket_manager.perf_loop())
141-
142-
shared.asyncio_tasks.append(sync_task)
143121
shared.asyncio_tasks.append(perf_task)
144122

123+
from core.config import config
124+
125+
if config.api.autoloaded_models:
126+
from core.shared_dependent import cached_model_list, gpu
127+
128+
all_models = cached_model_list.all()
129+
130+
for model in config.api.autoloaded_models:
131+
if model in [i.path for i in all_models]:
132+
backend: InferenceBackend = [i.backend for i in all_models if i.path == model][0] # type: ignore
133+
await gpu.load_model(model, backend)
134+
else:
135+
logger.warning(f"Autoloaded model {model} not found, skipping")
136+
145137
logger.info("Started WebSocketManager performance monitoring loop")
146-
logger.info("UI Available at: http://localhost:5003/")
138+
logger.info(f"UI Available at: http://localhost:{shared.api_port}/")
147139

148140

149141
@app.on_event("shutdown")
@@ -165,42 +157,49 @@ async def shutdown_event():
165157
# Mount routers
166158
## HTTP
167159
app.include_router(static.router)
168-
app.include_router(test.router, prefix="/api/test")
169-
app.include_router(generate.router, prefix="/api/generate")
170-
app.include_router(hardware.router, prefix="/api/hardware")
171-
app.include_router(models.router, prefix="/api/models")
172-
app.include_router(outputs.router, prefix="/api/output")
173-
app.include_router(general.router, prefix="/api/general")
174-
app.include_router(settings.router, prefix="/api/settings")
160+
161+
# Walk the routes folder and mount all routers
162+
for file in Path("api/routes").iterdir():
163+
if file.is_file():
164+
if (
165+
file.name != "__init__.py"
166+
and file.suffix == ".py"
167+
and file.stem not in ["static", "ws"]
168+
):
169+
logger.debug(f"Mounting: {file} as /api/{file.stem}")
170+
module = __import__(f"api.routes.{file.stem}", fromlist=["router"])
171+
app.include_router(module.router, prefix=f"/api/{file.stem}")
175172

176173
## WebSockets
177174
app.include_router(ws.router, prefix="/api/websockets")
178175

179176
# Mount outputs folder
180-
output_folder = Path("data/outputs")
181-
output_folder.mkdir(exist_ok=True)
182177
app.mount("/data/outputs", StaticFiles(directory="data/outputs"), name="outputs")
183178

184179
# Mount static files (css, js, images, etc.)
185180
static_app = FastAPI()
186-
static_app.add_middleware(
187-
CORSMiddleware,
188-
allow_origins=["*"],
189-
allow_credentials=True,
190-
allow_methods=["*"],
191-
allow_headers=["*"],
192-
)
193181
static_app.add_middleware(
194182
CacheControlMiddleware, cache_control=CacheControl("no-cache")
195183
)
196184
static_app.mount("/", StaticFiles(directory="frontend/dist/assets"), name="assets")
197185

198186
app.mount("/assets", static_app)
187+
app.mount("/static", StaticFiles(directory="static"), name="extra_static_files")
188+
app.mount("/themes", StaticFiles(directory="data/themes"), name="themes")
189+
190+
origins = ["*"]
199191

200192
# Allow CORS for specified origins
201193
app.add_middleware(
202194
CORSMiddleware,
203-
allow_origins=["*"],
195+
allow_origins=origins,
196+
allow_credentials=True,
197+
allow_methods=["*"],
198+
allow_headers=["*"],
199+
)
200+
static_app.add_middleware(
201+
CORSMiddleware,
202+
allow_origins=origins,
204203
allow_credentials=True,
205204
allow_methods=["*"],
206205
allow_headers=["*"],

api/routes/autofill.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import logging
2+
from pathlib import Path
3+
from typing import List
4+
5+
from fastapi import APIRouter
6+
7+
router = APIRouter(tags=["autofill"])
8+
logger = logging.getLogger(__name__)
9+
10+
11+
@router.get("/")
12+
def get_autofill_list() -> List[str]:
13+
"Gathers and returns all words from the prompt autofill files"
14+
15+
autofill_folder = Path("data/autofill")
16+
17+
words = []
18+
19+
logger.debug(f"Looking for autofill files in {autofill_folder}")
20+
logger.debug(f"Found {list(autofill_folder.iterdir())} files")
21+
22+
for file in autofill_folder.iterdir():
23+
if file.is_file():
24+
if file.suffix == ".txt":
25+
logger.debug(f"Found autofill file: {file}")
26+
with open(file, "r", encoding="utf-8") as f:
27+
words.extend(f.read().splitlines())
28+
29+
return list(set(words))

api/routes/general.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import sys
3+
from pathlib import Path
34

45
from fastapi import APIRouter
56

@@ -83,3 +84,16 @@ async def queue_clear():
8384
queue.clear()
8485

8586
return {"message": "Queue cleared"}
87+
88+
89+
@router.get("/themes")
90+
async def themes():
91+
"Get all available themes"
92+
93+
path = Path("data/themes")
94+
files = []
95+
for file in path.glob("*.json"):
96+
files.append(file.stem)
97+
98+
files.sort()
99+
return files

api/routes/generate.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@ async def txt2img_job(job: Txt2ImgQueueEntry):
3535
time: float
3636
images, time = await gpu.generate(job)
3737
except ModelNotLoadedError:
38-
raise HTTPException( # pylint: disable=raise-missing-from
39-
status_code=400, detail="Model is not loaded"
40-
)
38+
raise HTTPException(status_code=400, detail="Model is not loaded")
4139

4240
return images_to_response(images, time)
4341

@@ -55,9 +53,7 @@ async def img2img_job(job: Img2ImgQueueEntry):
5553
time: float
5654
images, time = await gpu.generate(job)
5755
except ModelNotLoadedError:
58-
raise HTTPException( # pylint: disable=raise-missing-from
59-
status_code=400, detail="Model is not loaded"
60-
)
56+
raise HTTPException(status_code=400, detail="Model is not loaded")
6157

6258
return images_to_response(images, time)
6359

@@ -79,16 +75,14 @@ async def inpaint_job(job: InpaintQueueEntry):
7975
time: float
8076
images, time = await gpu.generate(job)
8177
except ModelNotLoadedError:
82-
raise HTTPException( # pylint: disable=raise-missing-from
83-
status_code=400, detail="Model is not loaded"
84-
)
78+
raise HTTPException(status_code=400, detail="Model is not loaded")
8579

8680
return images_to_response(images, time)
8781

8882

8983
@router.post("/controlnet")
9084
async def controlnet_job(job: ControlNetQueueEntry):
91-
"Generate variations of the image"
85+
"Generate images based on a reference image"
9286

9387
image_bytes = job.data.image
9488
assert isinstance(image_bytes, bytes)
@@ -99,9 +93,7 @@ async def controlnet_job(job: ControlNetQueueEntry):
9993
time: float
10094
images, time = await gpu.generate(job)
10195
except ModelNotLoadedError:
102-
raise HTTPException( # pylint: disable=raise-missing-from
103-
status_code=400, detail="Model is not loaded"
104-
)
96+
raise HTTPException(status_code=400, detail="Model is not loaded")
10597

10698
return images_to_response(images, time)
10799

@@ -119,9 +111,7 @@ async def realesrgan_upscale_job(job: UpscaleQueueEntry):
119111
time: float
120112
image, time = await gpu.upscale(job)
121113
except ModelNotLoadedError:
122-
raise HTTPException( # pylint: disable=raise-missing-from
123-
status_code=400, detail="Model is not loaded"
124-
)
114+
raise HTTPException(status_code=400, detail="Model is not loaded")
125115

126116
return {
127117
"time": time,
@@ -133,7 +123,7 @@ async def realesrgan_upscale_job(job: UpscaleQueueEntry):
133123

134124
@router.post("/generate-aitemplate")
135125
async def generate_aitemplate(request: AITemplateBuildRequest):
136-
"Generate a AITemplate model from a local model"
126+
"Generate an AITemplate model from a local model"
137127

138128
await gpu.build_aitemplate_engine(request)
139129

@@ -142,7 +132,7 @@ async def generate_aitemplate(request: AITemplateBuildRequest):
142132

143133
@router.post("/generate-dynamic-aitemplate")
144134
async def generate_dynamic_aitemplate(request: AITemplateDynamicBuildRequest):
145-
"Generate a AITemplate engine from a local model"
135+
"Generate an AITemplate engine from a local model"
146136

147137
await gpu.build_dynamic_aitemplate_engine(request)
148138

api/routes/hardware.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,7 @@ async def gpu_memory(gpu_id: int):
4545
gpu_data = GPUStatCollection.new_query().gpus[gpu_id]
4646
return (gpu_data.memory_total, gpu_data.memory_free, "MB")
4747
except IndexError:
48-
raise HTTPException( # pylint: disable=raise-missing-from
49-
status_code=400, detail="GPU not found"
50-
)
48+
raise HTTPException(status_code=400, detail="GPU not found")
5149

5250

5351
@router.get("/capabilities")

0 commit comments

Comments
 (0)