Skip to content

Commit f90f43c

Browse files
authored
Merge pull request #13 from KernelA/new-params
New parameters added :minor:
2 parents 8d9d1a5 + d536937 commit f90f43c

12 files changed

+224
-50
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,4 +140,5 @@ cython_debug/
140140
*.jpg
141141
test.ipynb
142142
**/nms/*.c
143-
wheelhouse/
143+
wheelhouse/
144+
class_colors.json

processing.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import argparse
22
import os
3+
import logging
34

45
from yolo_models.processing.detection_processing import DetectorProcess
56
from yolo_models.log_set import init_logging
67

78

89
def main(args):
10+
logger = logging.getLogger()
11+
912
with DetectorProcess(args.checkpoint_path,
1013
args.shared_update_mem_name,
1114
args.shared_params_mem_name,
@@ -14,16 +17,29 @@ def main(args):
1417
args.image_height,
1518
args.num_channels,
1619
args.image_type) as det:
20+
21+
if os.path.exists(args.class_colormap):
22+
logger.info("Load class colors from '%s'", args.class_colormap)
23+
det.load_class_colormap(args.class_colormap)
24+
else:
25+
logger.info("Cannot find a file with class colors. Create it with defaults colors.")
26+
det.save_class_colormap(args.class_colormap)
27+
1728
det.start_processing()
1829

1930

31+
def check_file(path: str):
32+
if not os.path.isfile(path):
33+
raise FileNotFoundError(f"Cannot find: '{path}'")
34+
35+
2036
if __name__ == "__main__":
2137
parser = argparse.ArgumentParser()
2238

23-
parser.add_argument("-iw", "--image_width", type=int, default=640,
24-
help="Width of image to transfer between processes")
25-
parser.add_argument("-ih", "--image_height", type=int, default=640,
26-
help="Height of image to transfer between processes")
39+
parser.add_argument("-iw", "--image_width", type=int, default=1280,
40+
help="A maximum width of image to transfer between processes")
41+
parser.add_argument("-ih", "--image_height", type=int, default=1280,
42+
help="A maximum height of image to transfer between processes")
2743
parser.add_argument("-c", "--num_channels", type=int, default=3,
2844
help="A number of channels in image. 3 for RGB")
2945
parser.add_argument("--image_type", type=str,
@@ -36,13 +52,12 @@ def main(args):
3652
help="Name of shared memory for transfering of parameters")
3753
parser.add_argument("-p", "--checkpoint_path", type=str,
3854
required=True, help="A path to model .pt")
39-
parser.add_argument("--log_config", type=str, default="log_settings.yaml",
55+
parser.add_argument("--log_config", type=str, default=None,
4056
help="A path to settings for logging")
57+
parser.add_argument("--class_colormap", type=str, default="class_colors.json",
58+
help="A path to json with colors for each class")
4159

4260
args = parser.parse_args()
43-
44-
if not os.path.isfile(args.checkpoint_path):
45-
raise FileNotFoundError(f"Cannot find: '{args.checkpoint_path}'")
46-
61+
check_file(args.checkpoint_path)
4762
init_logging(log_config=args.log_config)
4863
main(args)

requirements.inference.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
11
opencv-python-headless>=4.7.0.0
22
onnxruntime>=1.11.0
3-
matplotlib~=3.7

setup.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ def get_req(filepath: str):
4646
setup(install_requires=get_req("requirements.inference.txt"),
4747
version=get_version(),
4848
packages=find_packages(include=[f"{PACKAGE_NAME}*"]),
49+
package_data={"": [f"{PACKAGE_NAME}/log_set/*.yaml"]},
4950
extras_require={
50-
"torch": get_req("requirements.torch.gpu.txt")
51-
},
52-
ext_modules=cythonize(exts,
53-
compiler_directives={"language_level": "3"},
54-
)
55-
)
51+
"torch": get_req("requirements.torch.gpu.txt")},
52+
ext_modules=cythonize(exts,
53+
compiler_directives={"language_level": "3"},
54+
)
55+
)

tests/test_color_conv.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import pytest
2+
3+
from yolo_models.processing.color_utils import hex_to_rgb, rgb_to_hex
4+
5+
6+
@pytest.mark.parametrize("r", [0, 128, 255])
7+
@pytest.mark.parametrize("g", [0, 128, 255])
8+
@pytest.mark.parametrize("b", [0, 128, 255])
9+
def test_rgb_seq(r, g, b):
10+
color = (r, g, b)
11+
assert hex_to_rgb(rgb_to_hex(color)) == color

touch_designer.py

Lines changed: 59 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
1+
import time
12
from multiprocessing import shared_memory
23

3-
import numpy as np
44
import cv2
5-
6-
from yolo_models.processing.info import BufferStates, States, ParamsIndex
5+
import numpy as np
6+
from yolo_models.processing.info import (BufferStates, DrawInfo, ParamsIndex,
7+
States)
78

89
# Sync with external process
910
SHARED_MEM_PARAMS_LIST = shared_memory.ShareableList(name="params")
1011

11-
WIDTH = SHARED_MEM_PARAMS_LIST[ParamsIndex.IMAGE_WIDTH]
12-
HEIGHT = SHARED_MEM_PARAMS_LIST[ParamsIndex.IMAGE_HEIGHT]
12+
MAX_IMAGE_WIDTH = SHARED_MEM_PARAMS_LIST[ParamsIndex.IMAGE_WIDTH]
13+
MAX_IMAGE_HEIGHT = SHARED_MEM_PARAMS_LIST[ParamsIndex.IMAGE_HEIGHT]
1314
DTYPE = SHARED_MEM_PARAMS_LIST[ParamsIndex.IMAGE_DTYPE]
1415
NUM_CHANNELS = SHARED_MEM_PARAMS_LIST[ParamsIndex.IMAGE_CHANNELS]
1516
EXIT = False
@@ -21,28 +22,33 @@
2122
SHARED_MEM_ARRAY = shared_memory.SharedMemory(
2223
name=SHARED_MEM_PARAMS_LIST[ParamsIndex.SHARED_ARRAY_MEM_NAME], create=False)
2324

24-
ARRAY = np.ndarray((WIDTH, HEIGHT, NUM_CHANNELS), dtype=DTYPE, buffer=SHARED_MEM_ARRAY.buf)
25+
ARRAY = np.ndarray((MAX_IMAGE_WIDTH, MAX_IMAGE_HEIGHT, NUM_CHANNELS),
26+
dtype=DTYPE, buffer=SHARED_MEM_ARRAY.buf)
2527

2628

2729
def onSetupParameters(scriptOp):
2830
page = scriptOp.appendCustomPage("Detection")
2931
p = page.appendFloat("Nms", label="Intersection ratio between bboxes")
30-
p.default = 0.5
3132
p.min = 0
3233
p.max = 1
33-
p = page.appendFloat("Score", label="Minimum score for object")
3434
p.default = 0.5
35+
p = page.appendFloat("Score", label="Minimum score for object")
3536
p.min = 0.1
3637
p.max = 1
38+
p.default = 0.5
3739
p = page.appendInt("Maxk", label="Maximum number of objects to detect based on score")
38-
p.default = 5
3940
p.min = 0
4041
p.max = 1000
42+
p.default = 5
4143
p = page.appendFloat("Eta", label="Filtering")
42-
p.default = 1.0
4344
p.min = 0
4445
p.clampMin = True
4546
p.max = 1000
47+
p.default = 1.0
48+
p = page.appendToggle("Drawtext", label="Draw text labels")
49+
p.default = False
50+
p = page.appendToggle("Drawscore", label="Draw score labels")
51+
p.default = False
4652
return
4753

4854

@@ -71,31 +77,65 @@ def onCook(scriptOp):
7177

7278
video_in = scriptOp.inputs[0]
7379
# By default, the image is flipped up. We flip it early
74-
frame = video_in.numpyArray(delayed=True, writable=False)
80+
image = video_in.numpyArray(delayed=True, writable=False)
7581

76-
if frame is None:
82+
if image is None:
83+
scriptOp.clear()
7784
return
7885

79-
if COLOR_CONVERSION is not None:
80-
image = cv2.cvtColor(frame, COLOR_CONVERSION)
86+
if image.shape[0] > MAX_IMAGE_HEIGHT:
87+
debug("Too large image height")
88+
scriptOp.clear()
89+
return
8190

82-
image = cv2.resize(image, (WIDTH, HEIGHT))
91+
if image.shape[1] > MAX_IMAGE_WIDTH:
92+
debug("Too large image width")
93+
scriptOp.clear()
94+
return
95+
96+
if COLOR_CONVERSION is not None:
97+
image = cv2.cvtColor(image, COLOR_CONVERSION)
8398

99+
SHARED_MEM_PARAMS_LIST[ParamsIndex.IMAGE_HEIGHT] = image.shape[0]
100+
SHARED_MEM_PARAMS_LIST[ParamsIndex.IMAGE_WIDTH] = image.shape[1]
84101
SHARED_MEM_PARAMS_LIST[ParamsIndex.SCORE_THRESH] = scriptOp.par.Score.eval()
85102
SHARED_MEM_PARAMS_LIST[ParamsIndex.IOU_THRESH] = scriptOp.par.Nms.eval()
86103
SHARED_MEM_PARAMS_LIST[ParamsIndex.TOP_K] = scriptOp.par.Maxk.eval()
87104
SHARED_MEM_PARAMS_LIST[ParamsIndex.ETA] = scriptOp.par.Eta.eval()
88-
np.copyto(ARRAY, image)
105+
draw_info = DrawInfo.DRAW_BBOX
106+
107+
if scriptOp.par.Drawtext.eval():
108+
draw_info |= DrawInfo.DRAW_TEXT
109+
110+
if scriptOp.par.Drawscore.eval():
111+
draw_info |= DrawInfo.DRAW_CONF
112+
113+
SHARED_MEM_PARAMS_LIST[ParamsIndex.DRAW_INFO] = int(draw_info)
89114

115+
ARRAY[:image.shape[0], :image.shape[1]] = image
90116
SHARED_MEM_UPDATE_STATES.buf[BufferStates.SERVER] = States.READY_SERVER_MESSAGE.value[0]
91117

118+
start_time = time.monotonic()
119+
120+
is_skip = False
121+
92122
while not EXIT and SHARED_MEM_UPDATE_STATES.buf[BufferStates.SERVER_ALIVE] == States.IS_SERVER_ALIVE.value[0] and SHARED_MEM_UPDATE_STATES.buf[BufferStates.CLIENT] != States.READY_CLIENT_MESSAGE.value[0]:
93-
pass
123+
time.sleep(1e-3)
124+
elapsed = time.monotonic() - start_time
125+
126+
if elapsed > 1:
127+
debug("Too long processing copy frame as is")
128+
is_skip = True
129+
break
94130

95131
if SHARED_MEM_UPDATE_STATES.buf[BufferStates.SERVER_ALIVE] != States.IS_SERVER_ALIVE.value[0]:
132+
scriptOp.clear()
96133
raise ValueError("Server process died")
97134

98-
scriptOp.copyNumpyArray(ARRAY)
99-
SHARED_MEM_UPDATE_STATES.buf[BufferStates.CLIENT] = States.NULL_STATE.value[0]
135+
if is_skip:
136+
scriptOp.clear()
137+
return
100138

139+
scriptOp.copyNumpyArray(ARRAY[:image.shape[0], :image.shape[1]])
140+
SHARED_MEM_UPDATE_STATES.buf[BufferStates.CLIENT] = States.NULL_STATE.value[0]
101141
return

yolo_models/log_set/init_log.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
import os
22
from typing import Optional
33
import logging.config
4+
import pkg_resources
45

56
import yaml
67

78

89
def init_logging(log_config: Optional[str] = None, log_env_var: str = "LOG_CONFIG"):
9-
path_to_config = log_config if log_config is not None else os.environ[log_env_var]
10+
path_to_config = log_config if log_config is not None else os.environ.get(log_env_var)
11+
12+
if path_to_config is None:
13+
path_to_config = pkg_resources.resource_filename(__name__, "log_settings.yaml")
1014

1115
with open(path_to_config, "rb") as file:
1216
log_config = yaml.safe_load(file)
File renamed without changes.

yolo_models/processing/base_processing.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
import numpy as np
66

7-
from .info import ParamsIndex, BufferStates, States
7+
from .info import ParamsIndex, BufferStates, States, DrawInfo
8+
89

910
class BaseProcessServer:
1011
def __init__(self,
@@ -15,7 +16,7 @@ def __init__(self,
1516
image_height: int,
1617
num_channels: int,
1718
image_dtype):
18-
self._logging = logging.getLogger("server_processing")
19+
self._logging = logging.getLogger("server_processing")
1920
dtype_size = np.dtype(image_dtype).itemsize
2021
self._image_size_bytes = image_height * image_with * num_channels * dtype_size
2122
self._image_width = image_with
@@ -30,9 +31,21 @@ def __init__(self,
3031
self._params_shared_mem_name = params_shared_mem_name
3132
self._array_shared_mem_name = array_shared_mem_name
3233

34+
def _get_shared_mem_update(self, create: bool):
35+
return shared_memory.SharedMemory(
36+
name=self._update_shared_mem_name, create=create, size=len(BufferStates))
37+
38+
def _get_shared_mem_array(self, create: bool):
39+
return shared_memory.SharedMemory(
40+
name=self._array_shared_mem_name, create=create, size=self._image_size_bytes)
41+
3342
def init_mem(self):
3443
assert self._sh_mem_update is None, "Memory already initialized"
35-
self._sh_mem_update = shared_memory.SharedMemory(name=self._update_shared_mem_name, create=True, size=len(BufferStates))
44+
try:
45+
self._sh_mem_update = self._get_shared_mem_update(True)
46+
except FileExistsError:
47+
self._logging.warning("Cannot create new shared memory. Use existed")
48+
self._sh_mem_update = self._get_shared_mem_update(False)
3649

3750
params = [None] * len(ParamsIndex)
3851
params[ParamsIndex.ETA] = 1.0
@@ -45,11 +58,17 @@ def init_mem(self):
4558
params[ParamsIndex.SHARED_ARRAY_MEM_NAME] = self._array_shared_mem_name
4659
params[ParamsIndex.SHARD_STATE_MEM_NAME] = self._update_shared_mem_name
4760
params[ParamsIndex.IMAGE_DTYPE] = self._image_dtype
61+
params[ParamsIndex.DRAW_INFO] = int(DrawInfo.DRAW_BBOX)
4862

4963
self._sh_mem_params = shared_memory.ShareableList(
5064
name=self._params_shared_mem_name, sequence=params
5165
)
52-
self._sh_mem_array = shared_memory.SharedMemory(name=self._array_shared_mem_name, create=True, size=self._image_size_bytes)
66+
67+
try:
68+
self._sh_mem_array = self._get_shared_mem_array(True)
69+
except FileExistsError:
70+
self._sh_mem_array = self._get_shared_mem_array(False)
71+
5372
self._shared_array = np.ndarray(
5473
(self._image_height, self._image_width, self._num_channels),
5574
dtype=self._image_dtype, buffer=self._sh_mem_array.buf)
@@ -95,7 +114,14 @@ def start_processing(self):
95114

96115
self._sh_mem_update.buf[BufferStates.SERVER] = States.NULL_STATE.value[0]
97116

98-
self.process(self._shared_array, self._sh_mem_params)
117+
image_height = self._sh_mem_params[ParamsIndex.IMAGE_HEIGHT]
118+
image_width = self._sh_mem_params[ParamsIndex.IMAGE_WIDTH]
119+
120+
actual_image = self._shared_array[:image_height, :image_width]
121+
122+
self.process(actual_image, self._sh_mem_params)
123+
self._shared_array[:image_height, :image_width] = actual_image
124+
99125
self._sh_mem_update.buf[BufferStates.CLIENT] = States.READY_CLIENT_MESSAGE.value[0]
100126

101127
def process(self, image: np.ndarray, params: list):

yolo_models/processing/color_utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from typing import Tuple
2+
import re
3+
4+
HEX_REGEX = re.compile("^#[0-9a-f]{6}$", re.IGNORECASE)
5+
6+
7+
def check_regex_color(hex_color: str):
8+
if HEX_REGEX.match(hex_color) is None:
9+
raise ValueError(f"Incorrect hex color: '{hex_color}'")
10+
11+
12+
def rgb_to_hex(rgb: Tuple[int]) -> str:
13+
return f"#{rgb[0]:02x}{rgb[1]:02x}{rgb[2]:02x}"
14+
15+
16+
def hex_to_rgb(hex_color: str) -> Tuple[int]:
17+
check_regex_color(hex_color)
18+
return tuple(int(hex_color[i: i + 2], 16) for i in range(1, len(hex_color), 2))

0 commit comments

Comments
 (0)