Tentatively eliminate graph break overhead #3741

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into base: main
Changes from 1 commit
83 changes: 42 additions & 41 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -2,7 +2,6 @@
 
 import logging
 from contextlib import nullcontext
-from tempfile import tempdir
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import tensorrt as trt
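
Note on the dropped import: `tempfile.tempdir` is a module-level global that defaults to `None`, not a usable directory path, so the old import fed a likely bug in the cudagraph debug-dump code further down, which interpolates `tempdir` where the local `tmpdir` from `tempfile.TemporaryDirectory()` was presumably intended. A standalone illustration:

```python
# tempfile.tempdir is a configurable default location (None unless set),
# not a ready-made directory; TemporaryDirectory() yields the real path.
import tempfile

print(tempfile.tempdir)  # None on a fresh interpreter

with tempfile.TemporaryDirectory() as tmpdir:
    print(tmpdir)  # e.g. /tmp/tmpabc123, valid only inside this block
```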
@@ -218,7 +217,8 @@ def __init__(
         self.requires_output_allocator = requires_output_allocator
         self.output_allocator: Optional[DynamicOutputAllocator] = None
         self.use_output_allocator_outputs = False
-
+        self.device = torch.cuda.current_device()
+        self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
         if self.serialized_engine is not None and not self.settings.lazy_engine_init:
             self.setup_engine()
 
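These two attributes hoist per-call queries into `__init__`: `create_output_tensors` previously called `torch.cuda.current_device()` for every allocation, and `forward` re-read the cudagraphs mode on every invocation (both visible in later hunks). A minimal sketch of the pattern, with illustrative names and assuming a CUDA-enabled build:

```python
# Query-once caching, sketched outside Torch-TensorRT (names illustrative).
import torch

class CachedRunner:
    def __init__(self) -> None:
        # Resolved once at construction instead of on every call; this is
        # an int device ordinal, which torch factory functions accept.
        self.device = torch.cuda.current_device()

    def make_output(self, shape: tuple) -> torch.Tensor:
        # Hot path: no per-call device lookup.
        return torch.empty(shape, device=self.device)
```

The trade-off is that a `cudagraphs_enabled` flag cached at construction no longer observes later toggles of the runtime cudagraphs mode, which seems acceptable for a draft aimed at measuring graph-break overhead.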
@@ -263,7 +263,12 @@ def setup_engine(self) -> None:
         assert (
             self.target_platform == Platform.current_platform()
         ), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})"
-
+        self._caller_stream = torch.cuda.current_stream()
+        if (
+            self._engine_stream == torch.cuda.default_stream()
+            or self._engine_stream is None
+        ):
+            self._engine_stream = torch.cuda.Stream()
         self.initialized = True
         runtime = trt.Runtime(TRT_LOGGER)
         self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
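
Stream setup moves out of the per-call path: the engine stream is created once here, and the matching block is deleted from `run_standard_execution` below. A hedged sketch of the resulting pattern, with illustrative names and assuming a CUDA device is present:

```python
# Persistent side stream: allocate once, then pay only a per-call event wait.
import torch

engine_stream = torch.cuda.Stream()  # created once, at engine setup

def run_once(enqueue_work) -> None:
    caller_stream = torch.cuda.current_stream()
    engine_stream.wait_stream(caller_stream)   # per call: ordering only
    with torch.cuda.stream(engine_stream):
        enqueue_work()                         # engine execution goes here
    caller_stream.wait_stream(engine_stream)   # hand results back
```

Note that the patch also freezes `self._caller_stream` at setup time, which implicitly assumes `forward` is always invoked from the stream that was current during `setup_engine`.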
@@ -286,10 +291,14 @@ def setup_engine(self) -> None:
             for output_name in self.output_names
         ]
         self.output_shapes = [
-            self.engine.get_tensor_shape(output_name)
+            tuple(self.context.get_tensor_shape(output_name))
             for output_name in self.output_names
         ]
 
+        self.shape_key = "".join(
+            str(tuple(t)).replace(" ", "") for t in self.input_shapes
+        )
+
         if self.requires_output_allocator:
             self.create_output_allocator()
 
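Two related changes land here: output shapes are now read from the execution context and stored as plain tuples (the same form the per-call code consumed before it was commented out below), and `self.shape_key` is precomputed so the first shape check can compare a cached string rather than rebuild it. A small sketch of the shape-key idea; the helper names are assumptions, not the module's API:

```python
# Serialize expected input shapes once; the per-call check becomes one
# string comparison (helper names are illustrative).
from typing import Sequence, Tuple

def make_shape_key(shapes: Sequence[Tuple[int, ...]]) -> str:
    return "".join(str(tuple(s)).replace(" ", "") for s in shapes)

cached_key = make_shape_key([(1, 3, 224, 224), (1, 10)])

def shapes_changed(current: Sequence[Tuple[int, ...]]) -> bool:
    return make_shape_key(current) != cached_key
```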
@@ -370,9 +379,9 @@ def setup_input_tensors(
                 + contiguous_inputs[i + 1 :]
             )
 
-            assert (
-                contiguous_inputs[i].dtype == self.input_dtypes[i]
-            ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
+            # assert (
+            #     contiguous_inputs[i].dtype == self.input_dtypes[i]
+            # ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
 
             if need_cudagraphs_record:
                 # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
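
Commenting out the assert removes Python-level dtype checking from the hot loop, but it also removes the only guard against silently feeding mis-typed inputs to the engine. If per-call overhead is the concern, one hedged alternative (a sketch only, not part of this diff) is to validate each distinct dtype signature once:

```python
# Validate once per distinct dtype signature instead of on every call.
from typing import Sequence
import torch

_seen_signatures: set = set()

def check_dtypes_once(
    inputs: Sequence[torch.Tensor], expected: Sequence[torch.dtype]
) -> None:
    sig = tuple(t.dtype for t in inputs)
    if sig in _seen_signatures:
        return  # fast path after the first successful validation
    for i, (tensor, dtype) in enumerate(zip(inputs, expected)):
        if tensor.dtype != dtype:
            raise TypeError(
                f"Dtype mismatch for input {i}: expected {dtype}, got {tensor.dtype}"
            )
    _seen_signatures.add(sig)
```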
@@ -409,7 +418,7 @@ def create_output_tensors(self) -> List[torch.Tensor]:
             output = torch.empty(
                 size=self.output_shapes[o],
                 dtype=self.output_dtypes[o],
-                device=torch.cuda.current_device(),
+                device=self.device,
             )
             outputs.append(output)
         return outputs
@@ -480,10 +489,10 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
             if can_use_pre_allocated_outputs:
                 outputs = self.pre_allocated_outputs
             else:
-                self.output_shapes = [
-                    tuple(self.context.get_tensor_shape(output_name))
-                    for output_name in self.output_names
-                ]
+                # self.output_shapes = [
+                #     tuple(self.context.get_tensor_shape(output_name))
+                #     for output_name in self.output_names
+                # ]
                 if DYNAMIC_DIM in self.output_shapes:
                     raise ValueError(
                         "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
@@ -510,42 +519,36 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 if self.profiling_enabled
                 else nullcontext()
             ):
-                self._caller_stream = torch.cuda.current_stream()
-                if (
-                    self._engine_stream == torch.cuda.default_stream()
-                    or self._engine_stream is None
-                ):
-                    self._engine_stream = torch.cuda.Stream()
 
                 self._engine_stream.wait_stream(self._caller_stream)
 
-                with torch.cuda.stream(self._engine_stream):
-                    if self.cudagraphs_enabled:
-                        if need_cudagraphs_record:
-                            self.cudagraph = torch.cuda.CUDAGraph()
+                # with torch.cuda.stream(self._engine_stream):
+                # if self.cudagraphs_enabled:
+                #     if need_cudagraphs_record:
+                #         self.cudagraph = torch.cuda.CUDAGraph()
 
-                            if self.profiling_enabled:
-                                self.cudagraph.enable_debug_mode()
+                # if self.profiling_enabled:
+                #     self.cudagraph.enable_debug_mode()
 
-                            with torch.cuda.graph(
-                                self.cudagraph, stream=self._engine_stream
-                            ):
-                                self.context.execute_async_v3(
-                                    self._engine_stream.cuda_stream
-                                )
+                # with torch.cuda.graph(
+                #     self.cudagraph, stream=self._engine_stream
+                # ):
+                #     self.context.execute_async_v3(
+                #         self._engine_stream.cuda_stream
+                #     )
 
-                            if self.profiling_enabled:
-                                import tempfile
+                # if self.profiling_enabled:
+                #     import tempfile
 
-                                with tempfile.TemporaryDirectory() as tmpdir:
-                                    self.cudagraph.debug_dump(
-                                        f"{tempdir}/{self.name}_cudagraph.dot"
-                                    )
+                #     with tempfile.TemporaryDirectory() as tmpdir:
+                #         self.cudagraph.debug_dump(
+                #             f"{tempdir}/{self.name}_cudagraph.dot"
+                #         )
 
-                            self.cudagraph.replay()  # type: ignore
+                # self.cudagraph.replay()  # type: ignore
 
-                        else:
-                            self.context.execute_async_v3(self._engine_stream.cuda_stream)
+                # else:
+                self.context.execute_async_v3(self._engine_stream.cuda_stream)
 
                 self._caller_stream.wait_stream(self._engine_stream)
 
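Together with the hunk above (which stops re-querying `get_tensor_shape` per call and checks `DYNAMIC_DIM` against the shapes cached in `setup_engine`), the per-call work inside the profiling context reduces to wait, enqueue, wait. Since `execute_async_v3` takes an explicit CUDA stream handle, the `torch.cuda.stream(...)` context manager is not needed to route the work; and because the whole capture/replay branch is commented out rather than gated, cudagraphs appear effectively disabled in this module while the overhead is measured, consistent with the "tentatively" in the title. The surviving hot path, as a sketch:

```python
# The per-call hot path this hunk leaves behind (method names as in the
# diff; `self` stands for the PythonTorchTensorRTModule instance).
def hot_path(self) -> None:
    self._engine_stream.wait_stream(self._caller_stream)             # order after caller work
    self.context.execute_async_v3(self._engine_stream.cuda_stream)   # enqueue the TRT engine
    self._caller_stream.wait_stream(self._engine_stream)             # order results for caller
```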
@@ -646,8 +649,6 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
 
             return outputs
 
-        self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
-
         # Run forward function
         contiguous_inputs: List[torch.Tensor] = [
             (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())