Tentatively eliminate graph break overhead #3741

Draft · wants to merge 2 commits into base: main
105 changes: 56 additions & 49 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -2,7 +2,6 @@

import logging
from contextlib import nullcontext
-from tempfile import tempdir
from typing import Any, Dict, List, Optional, Sequence, Tuple

import tensorrt as trt
@@ -174,6 +173,8 @@ def __init__(
self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
self._caller_stream: Optional[torch.cuda.Stream] = None
self._engine_stream: Optional[torch.cuda.Stream] = None
+        self.output_tensors: Optional[List[torch.Tensor]] = None
+        self.sync_stream = True

# TODO: Make the below a Dictionary {shape: cudagraph}
self.shape_key: Optional[str] = None
Expand Down Expand Up @@ -218,7 +219,8 @@ def __init__(
self.requires_output_allocator = requires_output_allocator
self.output_allocator: Optional[DynamicOutputAllocator] = None
self.use_output_allocator_outputs = False

+        self.device = torch.cuda.current_device()
+        self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
if self.serialized_engine is not None and not self.settings.lazy_engine_init:
self.setup_engine()

@@ -263,6 +265,15 @@ def setup_engine(self) -> None:
assert (
self.target_platform == Platform.current_platform()
), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})"
+        # Stream handling: if the caller stream is the PyTorch default stream, create a new engine stream;
+        # otherwise, use the caller stream and disable stream synchronization.
+        self._caller_stream = torch.cuda.current_stream()
+        if self._caller_stream == torch.cuda.default_stream():
+            self._engine_stream = torch.cuda.Stream()
+            self.sync_stream = True
+        else:
+            self._engine_stream = self._caller_stream
+            self.sync_stream = False

self.initialized = True
runtime = trt.Runtime(TRT_LOGGER)
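The idea behind this hunk: synchronization between the caller's stream and a separate engine stream is only needed when the caller sits on PyTorch's default stream, and that decision can be made once at engine setup instead of on every forward call. A minimal standalone sketch of the selection logic (assumes a CUDA device; variable names are illustrative, not the module's attributes):

```python
import torch

# One-time stream selection, mirroring the logic added to setup_engine above.
caller_stream = torch.cuda.current_stream()
if caller_stream == torch.cuda.default_stream():
    # Caller is on the default stream: run the engine on a dedicated side
    # stream and fence with wait_stream() before and after execution.
    engine_stream = torch.cuda.Stream()
    sync_stream = True
else:
    # Caller already manages a non-default stream: reuse it and skip the
    # per-call wait_stream() synchronization entirely.
    engine_stream = caller_stream
    sync_stream = False
```

Hoisting this out of the hot path is where the per-call overhead savings come from.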
@@ -286,10 +297,14 @@ def setup_engine(self) -> None:
for output_name in self.output_names
]
self.output_shapes = [
-            self.engine.get_tensor_shape(output_name)
+            tuple(self.context.get_tensor_shape(output_name))
for output_name in self.output_names
]

+        self.shape_key = "".join(
+            str(tuple(t)).replace(" ", "") for t in self.input_shapes
+        )

if self.requires_output_allocator:
self.create_output_allocator()
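The shape key precomputed here uses the same encoding that validate_input_shapes documents further down (x: (3, 4), y: (4, 5) → (3,4)(4,5)), so later calls can compare one string instead of rebuilding state. A tiny self-contained example with made-up shapes:

```python
# Hypothetical static input shapes, encoded the same way as self.shape_key.
input_shapes = [(3, 4), (4, 5)]
shape_key = "".join(str(tuple(s)).replace(" ", "") for s in input_shapes)
assert shape_key == "(3,4)(4,5)"
```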

@@ -370,9 +385,9 @@ def setup_input_tensors(
+ contiguous_inputs[i + 1 :]
)

-            assert (
-                contiguous_inputs[i].dtype == self.input_dtypes[i]
-            ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
+            # assert (
+            #     contiguous_inputs[i].dtype == self.input_dtypes[i]
+            # ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."

if need_cudagraphs_record:
# If cudagraphs is enabled, this memory is reserved for future cudagraph runs
@@ -409,7 +424,7 @@ def create_output_tensors(self) -> List[torch.Tensor]:
output = torch.empty(
size=self.output_shapes[o],
dtype=self.output_dtypes[o],
-                device=torch.cuda.current_device(),
+                device=self.device,
)
outputs.append(output)
return outputs
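Together with the cached self.device, the new self.output_tensors field lets the runtime allocate outputs once and hand the same buffers back on every call while shapes stay static. A sketch of that reuse pattern in isolation (a hypothetical helper, not the module's API):

```python
from typing import List, Optional, Tuple

import torch

class OutputCache:
    """Allocate output buffers once; return the same tensors on every call."""

    def __init__(
        self,
        shapes: List[Tuple[int, ...]],
        dtype: torch.dtype,
        device: torch.device,
    ) -> None:
        self.shapes = shapes
        self.dtype = dtype
        self.device = device
        self.tensors: Optional[List[torch.Tensor]] = None

    def get(self) -> List[torch.Tensor]:
        if self.tensors is None:  # first call: allocate
            self.tensors = [
                torch.empty(s, dtype=self.dtype, device=self.device)
                for s in self.shapes
            ]
        return self.tensors  # later calls: reuse, no allocation
```

The trade-off, and plausibly why the PR is marked tentative: callers receive the same storage each time, so a result must be consumed or cloned before the next invocation overwrites it.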
@@ -480,15 +495,14 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
if can_use_pre_allocated_outputs:
outputs = self.pre_allocated_outputs
else:
-                self.output_shapes = [
-                    tuple(self.context.get_tensor_shape(output_name))
-                    for output_name in self.output_names
-                ]

if DYNAMIC_DIM in self.output_shapes:
raise ValueError(
"Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
)
-                outputs = self.create_output_tensors()
+                if self.output_tensors is None:
+                    self.output_tensors = self.create_output_tensors()
+                outputs = self.output_tensors

for o, output_name in enumerate(self.output_names):
if need_cudagraphs_record:
@@ -510,44 +524,39 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
if self.profiling_enabled
else nullcontext()
):
-            self._caller_stream = torch.cuda.current_stream()
-            if (
-                self._engine_stream == torch.cuda.default_stream()
-                or self._engine_stream is None
-            ):
-                self._engine_stream = torch.cuda.Stream()
-
-            self._engine_stream.wait_stream(self._caller_stream)
+            if self.sync_stream:
+                self._engine_stream.wait_stream(self._caller_stream)

with torch.cuda.stream(self._engine_stream):
-                if self.cudagraphs_enabled:
-                    if need_cudagraphs_record:
-                        self.cudagraph = torch.cuda.CUDAGraph()
+                if self.cudagraphs_enabled:
+                    if need_cudagraphs_record:
+                        self.cudagraph = torch.cuda.CUDAGraph()

-                        if self.profiling_enabled:
-                            self.cudagraph.enable_debug_mode()
+                        if self.profiling_enabled:
+                            self.cudagraph.enable_debug_mode()

-                        with torch.cuda.graph(
-                            self.cudagraph, stream=self._engine_stream
-                        ):
-                            self.context.execute_async_v3(
-                                self._engine_stream.cuda_stream
-                            )
+                        with torch.cuda.graph(
+                            self.cudagraph, stream=self._engine_stream
+                        ):
+                            self.context.execute_async_v3(
+                                self._engine_stream.cuda_stream
+                            )

-                        if self.profiling_enabled:
-                            import tempfile
+                        if self.profiling_enabled:
+                            import tempfile

-                            with tempfile.TemporaryDirectory() as tmpdir:
-                                self.cudagraph.debug_dump(
-                                    f"{tempdir}/{self.name}_cudagraph.dot"
-                                )
+                            with tempfile.TemporaryDirectory() as tmpdir:
+                                self.cudagraph.debug_dump(
+                                    f"{tmpdir}/{self.name}_cudagraph.dot"
+                                )

-                    self.cudagraph.replay()  # type: ignore
+                    self.cudagraph.replay()  # type: ignore

-                else:
-                    self.context.execute_async_v3(self._engine_stream.cuda_stream)
+                else:
+                    self.context.execute_async_v3(self._engine_stream.cuda_stream)

-            self._caller_stream.wait_stream(self._engine_stream)
+            if self.sync_stream:
+                self._caller_stream.wait_stream(self._engine_stream)

if self.use_pre_allocated_outputs:
self.pre_allocated_outputs = self.create_output_tensors()
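For reference, the record-once/replay pattern the cudagraphs branch relies on, reduced to a generic PyTorch sketch (the multiply stands in for context.execute_async_v3; assumes a CUDA device):

```python
import torch

static_in = torch.randn(8, device="cuda")
static_out = torch.empty_like(static_in)

side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    static_out.copy_(static_in * 2.0)  # warm-up on a side stream before capture
torch.cuda.current_stream().wait_stream(side)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=side):
    static_out.copy_(static_in * 2.0)  # captured into the graph, not executed

static_in.fill_(3.0)  # update inputs in place...
graph.replay()        # ...and replay the captured work on the new data
torch.cuda.synchronize()
assert torch.equal(static_out, torch.full_like(static_out, 6.0))
```

Replay skips kernel-launch overhead, which is why the input and output buffers must be the same tensors across calls, matching the pre-allocated outputs above.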
@@ -646,8 +655,6 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:

return outputs

-        self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()

# Run forward function
contiguous_inputs: List[torch.Tensor] = [
(i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
@@ -752,13 +759,13 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
# Representation of input shapes to a given model
# Shapes are concatenated as so:
# x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
-        tensor_inputs = []
-        for t in inputs:
-            if not isinstance(t, torch.Tensor):
-                return True
-            tensor_inputs.append(t)
+        if not all(isinstance(t, torch.Tensor) for t in inputs):
+            return True

new_shape_key = "".join(
-            str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs
+            str(tuple(t.shape)).replace(" ", "")
+            for t in inputs
+            if isinstance(t, torch.Tensor)
)

# If the new shape key differs from the existing one,
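The rewritten validate_input_shapes logic, sketched standalone: bail out (forcing the conservative re-record path) on any non-tensor input, otherwise compare the recomputed key against the cached one. This is a hypothetical free function; the method itself also updates self.shape_key when it changes:

```python
from typing import Any, Sequence

import torch

def shapes_changed(cached_key: str, inputs: Sequence[Any]) -> bool:
    # Non-tensor inputs: treat as changed, matching the method's early return.
    if not all(isinstance(t, torch.Tensor) for t in inputs):
        return True
    new_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in inputs)
    return new_key != cached_key
```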