
Commit 11b77fd

Stax124, gabe56f, aaronsantiago, MiningP, and samchouse authored
v0.2 (#88)
Merge features from the experimental branch into the main branch (stable)

---------

Co-authored-by: Márton Kissik <[email protected]>
Co-authored-by: aaronsantiago <[email protected]>
Co-authored-by: MiningP <[email protected]>
Co-authored-by: Samuel Corsi-House <[email protected]>
1 parent a1248ec commit 11b77fd

File tree

210 files changed (+28,058 / −7,628 lines)


.github/workflows/yarn_build.yml renamed to .github/workflows/frontend_build.yml

Lines changed: 7 additions & 4 deletions
@@ -6,6 +6,9 @@ name: Frontend Build Test
 on:
   push:
     branches: [main, experimental]
+    paths:
+      - "frontend/**"
+      - ".github/workflows/*.yml"
   pull_request:
     branches: [main, experimental]

@@ -24,10 +27,10 @@ jobs:
       - name: Run install
         uses: borales/actions-yarn@v4
         with:
-          cmd: install # will run `yarn install` command
-          dir: "frontend" # will run `yarn install` in `frontend` sub folder
+          cmd: install
+          dir: "frontend"
       - name: Build production bundle
         uses: borales/actions-yarn@v4
         with:
-          cmd: build # will run `yarn build:prod` command
-          dir: "frontend" # will run `yarn build` in `frontend` sub folder
+          cmd: build
+          dir: "frontend"

.github/workflows/manager_build.yml

Lines changed: 15 additions & 3 deletions
@@ -5,25 +5,37 @@ on:
     branches: ["main", "experimental"]
     paths:
       - "**.rs"
+      - ".github/workflows/**"
   pull_request:
     branches: ["main", "experimental"]
     paths:
       - "**.rs"
+      - ".github/workflows/**"
+  workflow_dispatch:
+    inputs:
+      branch:
+        description: "Branch to build"
+        required: true
+        default: "main"

 env:
   CARGO_TERM_COLOR: always

 jobs:
   build:
-    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        os: [ubuntu-latest, windows-latest]
+
+    runs-on: ${{ matrix.os }}

     steps:
       - uses: actions/checkout@v3
       - name: Build
-        run: cd manager && cargo build --verbose --release
+        run: cd manager && cargo build --release
       - name: Output binary
         uses: actions/upload-artifact@v3
         with:
           name: volta-manager
-          path: manager/target/release/voltaml-manager
+          path: manager/target/release/voltaml-manager*
         retention-days: 3
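
With workflow_dispatch in place, the manager build can also be triggered by hand. As a rough illustration (not part of this commit; the token and branch values are placeholders), a run could be started through GitHub's REST API:

import requests

# Hypothetical manual trigger for the workflow added above.
# The repository path is real; the token and branch are placeholders.
resp = requests.post(
    "https://api.github.com/repos/VoltaML/voltaML-fast-stable-diffusion"
    "/actions/workflows/manager_build.yml/dispatches",
    headers={
        "Accept": "application/vnd.github+json",
        "Authorization": "Bearer <YOUR_GITHUB_TOKEN>",
    },
    json={"ref": "main", "inputs": {"branch": "main"}},
)
resp.raise_for_status()  # GitHub answers 204 No Content on success

The matrix build produces one binary per OS, which is why the artifact path gains a trailing wildcard: it matches both voltaml-manager and voltaml-manager.exe.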

.gitignore

Lines changed: 8 additions & 1 deletion
@@ -74,7 +74,6 @@ docs/.vitepress/dist
 /engine
 /static/output
 /outputs
-/frontend/yarn.lock
 /testing
 /typings
 external
@@ -88,3 +87,11 @@ core/submodules

 # Manager
 manager/target
+voltaml-manager
+voltaml-manager.exe
+
+# Profiler stuff
+profile.html
+
+# dotenv file
+.env

README.md

Lines changed: 3 additions & 14 deletions
@@ -53,7 +53,7 @@
 </div>

 <hr>
-<h3 align="center">Made with ❤️ by <a href="https://github.com/Stax124/">Stax124</a></h3>
+<h3 align="center">Made with ❤️ by <a href="https://github.com/Stax124/">Stax124</a> and the community</h3>
 <hr>

 <br />
@@ -67,7 +67,6 @@
 - [Speed comparison](#speed-comparison)
 - [Installation](#installation)
 - [Contributing](#contributing)
-- [Code of Conduct](#code-of-conduct)
 - [License](#license)
 - [Contact](#contact)

@@ -136,17 +135,11 @@

 ## Speed comparison

-The below benchmarks have been done for generating a 512x512 image, batch size 1 for 50 iterations.
-
-| GPU (it/s) | T4  | A10  | A100 | 4090 | 4080 | 3090 | 2080Ti | 3050 |
-| ---------- | --- | ---- | ---- | ---- | ---- | ---- | ------ | ---- |
-| PyTorch    | 4.3 | 8.8  | 15.1 | 19   | 15.5 | 11   | 8      | 4.1  |
-| xFormers   | 5.5 | 15.6 | 27.5 | 28   | 20.2 | 15.7 | N/A    | 5.1  |
-| AITemplate | N/A | 23   | N/A  | N/A  | 40.5 | N/A  | N/A    | 10.2 |
+Please refer to this [table](https://voltaml.github.io/voltaML-fast-stable-diffusion/getting-started/introduction#speed-comparison). Data had a small sample size and was usually collected on a single machine. Your results may vary.

 ## Installation

-Please see the [documentation](https://voltaml.github.io/voltaML-fast-stable-diffusion/installation/docker.html) for installation instructions.
+Please see the [documentation](https://voltaml.github.io/voltaML-fast-stable-diffusion/installation/local) for installation instructions.

 # Contributing

@@ -158,10 +151,6 @@ Contributions are always welcome!

 See `contributing.md` for ways to get started.

-## Code of Conduct
-
-Please read the [Code of Conduct](https://github.com/VoltaML/voltaML-fast-stable-diffusion/blob/master/CODE_OF_CONDUCT.md)
-
 # License

 Distributed under the <b>GPL v3</b>. See [License](https://github.com/VoltaML/voltaML-fast-stable-diffusion/blob/experimental/License) for more information.

api/app.py

Lines changed: 32 additions & 15 deletions
@@ -89,11 +89,31 @@ async def custom_http_exception_handler(_request, _exc):
 async def startup_event():
     "Prepare the event loop for other asynchronous tasks"

+    # Inject the logger
+    from rich.logging import RichHandler
+
+    # Disable duplicate logger
+    logging.getLogger("uvicorn").handlers = []
+
+    for logger_ in ("uvicorn.access", "uvicorn.error", "fastapi"):
+        l = logging.getLogger(logger_)
+        handler = RichHandler(rich_tracebacks=True, show_time=False)
+        handler.setFormatter(
+            logging.Formatter(
+                fmt="%(asctime)s | %(name)s » %(message)s", datefmt="%H:%M:%S"
+            )
+        )
+        l.handlers = [handler]
+
     shared.asyncio_loop = asyncio.get_event_loop()

-    asyncio.create_task(websocket_manager.sync_loop())
+    sync_task = asyncio.create_task(websocket_manager.sync_loop())
     logger.info("Started WebSocketManager sync loop")
-    asyncio.create_task(websocket_manager.perf_loop())
+    perf_task = asyncio.create_task(websocket_manager.perf_loop())
+
+    shared.asyncio_tasks.append(sync_task)
+    shared.asyncio_tasks.append(perf_task)
+
     logger.info("Started WebSocketManager performance monitoring loop")
     logger.info("UI Available at: http://localhost:5003/")

@@ -106,18 +126,6 @@ async def shutdown_event():
     await websocket_manager.close_all()


-# Origins that are allowed to access the API
-origins = ["*"]
-
-# Allow CORS for specified origins
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=origins,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
 # Enable FastAPI Analytics if key is provided
 key = os.getenv("FASTAPI_ANALYTICS_KEY")
 if key:
@@ -149,7 +157,7 @@ async def shutdown_event():
 static_app = FastAPI()
 static_app.add_middleware(
     CORSMiddleware,
-    allow_origins=origins,
+    allow_origins=["*"],
     allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
@@ -160,3 +168,12 @@ async def shutdown_event():
 static_app.mount("/", StaticFiles(directory="frontend/dist/assets"), name="assets")

 app.mount("/assets", static_app)
+
+# Allow CORS for specified origins
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
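
The handler injection added to startup_event can be reproduced standalone. This is a minimal sketch of the same pattern, not part of the commit, assuming only the rich package and the standard logging module:

import logging

from rich.logging import RichHandler

# Drop uvicorn's default handlers so records are not emitted twice.
logging.getLogger("uvicorn").handlers = []

# Route uvicorn and fastapi records through Rich, mirroring the diff above.
for name in ("uvicorn.access", "uvicorn.error", "fastapi"):
    log = logging.getLogger(name)
    handler = RichHandler(rich_tracebacks=True, show_time=False)
    handler.setFormatter(
        logging.Formatter(
            fmt="%(asctime)s | %(name)s » %(message)s", datefmt="%H:%M:%S"
        )
    )
    log.handlers = [handler]

logging.getLogger("fastapi").error("formatted by RichHandler")

Note also that the app-level CORS middleware now registers after the static mount. Starlette accepts add_middleware calls at any point before the app starts serving requests, so the move is mainly organizational.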

api/routes/general.py

Lines changed: 74 additions & 0 deletions
@@ -1,8 +1,14 @@
+import logging
+import sys
+
 from fastapi import APIRouter

+from api import websocket_manager
+from api.websockets.notification import Notification
 from core import shared

 router = APIRouter(tags=["general"])
+logger = logging.getLogger(__name__)


 @router.post("/interrupt")
@@ -11,3 +17,71 @@ async def interrupt():

     shared.interrupt = True
     return {"message": "Interupted"}
+
+
+@router.post("/shutdown")
+async def shutdown():
+    "Shutdown the server"
+
+    from core.config import config
+    from core.shared import uvicorn_loop, uvicorn_server
+
+    if config.api.enable_shutdown:
+        if uvicorn_server is not None:
+            await websocket_manager.broadcast(
+                data=Notification(
+                    message="Shutting down the server",
+                    severity="warning",
+                    title="Shutdown",
+                )
+            )
+            for task in shared.asyncio_tasks:
+                task.cancel()
+            uvicorn_server.force_exit = True
+            logger.debug("Setting force_exit to True")
+
+        assert uvicorn_server is not None
+        await uvicorn_server.shutdown()
+        logger.debug("Unicorn server shutdown")
+
+        assert uvicorn_loop is not None
+        uvicorn_loop.stop()
+        logger.debug("Unicorn loop stopped")
+
+        sys.exit(0)
+
+    else:
+        await websocket_manager.broadcast(
+            data=Notification(
+                message="Shutdown is disabled", severity="error", title="Shutdown"
+            )
+        )
+        return {"message": "Shutdown is disabled"}
+
+
+@router.get("/queue-status")
+async def queue_status():
+    "Get the status of the queue"
+
+    from core.shared_dependent import gpu
+
+    queue = gpu.queue
+
+    return {
+        "jobs": queue.jobs,
+        "concurrent_jobs": queue.concurrent_jobs,
+        "locked": queue.lock.locked(),
+    }
+
+
+@router.post("/queue-clear")
+async def queue_clear():
+    "Clear the queue"
+
+    from core.shared_dependent import gpu
+
+    queue = gpu.queue
+
+    queue.clear()
+
+    return {"message": "Queue cleared"}

api/routes/generate.py

Lines changed: 20 additions & 21 deletions
@@ -8,20 +8,21 @@
 from core.shared_dependent import gpu
 from core.types import (
     AITemplateBuildRequest,
+    AITemplateDynamicBuildRequest,
     ControlNetQueueEntry,
     ConvertModelRequest,
     Img2ImgQueueEntry,
     InpaintQueueEntry,
     InterrogatorQueueEntry,
     ONNXBuildRequest,
-    RealESRGANQueueEntry,
     SDUpscaleQueueEntry,
     TRTBuildRequest,
     Txt2ImgQueueEntry,
+    UpscaleQueueEntry,
 )
 from core.utils import convert_bytes_to_image_stream, convert_image_to_base64

-router = APIRouter(tags=["txt2img"])
+router = APIRouter(tags=["generate"])
 logger = logging.getLogger(__name__)


@@ -195,38 +196,27 @@ async def sd_upscale_job(job: SDUpscaleQueueEntry):
     }


-@router.post("/realesrgan-upscale")
-async def realesrgan_upscale_job(job: RealESRGANQueueEntry):
+@router.post("/upscale")
+async def realesrgan_upscale_job(job: UpscaleQueueEntry):
     "Upscale image with RealESRGAN model"

     image_bytes = job.data.image
     assert isinstance(image_bytes, bytes)
     job.data.image = convert_bytes_to_image_stream(image_bytes)

     try:
-        images: Union[List[Image.Image], List[str]]
+        image: Image.Image
         time: float
-        images, time = await gpu.generate(job)
+        image, time = await gpu.upscale(job)
     except ModelNotLoadedError:
         raise HTTPException(  # pylint: disable=raise-missing-from
             status_code=400, detail="Model is not loaded"
         )

-    if len(images) == 0:
-        return {
-            "time": time,
-            "images": [],
-        }
-    elif isinstance(images[0], str):
-        return {
-            "time": time,
-            "images": images,
-        }
-    else:
-        return {
-            "time": time,
-            "images": [convert_image_to_base64(i) for i in images],  # type: ignore
-        }
+    return {
+        "time": time,
+        "images": convert_image_to_base64(image),  # type: ignore
+    }


 @router.post("/generate-engine")
@@ -247,6 +237,15 @@ async def generate_aitemplate(request: AITemplateBuildRequest):
     return {"message": "Success"}


+@router.post("/generate-dynamic-aitemplate")
+async def generate_dynamic_aitemplate(request: AITemplateDynamicBuildRequest):
+    "Generate a TensorRT engine from a local model"
+
+    await gpu.build_dynamic_aitemplate_engine(request)
+
+    return {"message": "Success"}
+
+
 @router.post("/generate-onnx")
 async def generate_onnx(request: ONNXBuildRequest):
     "Generate a TensorRT engine from a local model"
