88 4. Produces numerically accurate results vs. PyTorch eager.
99"""
1010
11- import unittest
12-
1311import torch
1412import torch .nn as nn
1513import torch_tensorrt
1816import triton
1917import triton .language as tl
2018
21- from . _e2e_common import _compile_and_run , assert_trt_compiled
19+ from torch . testing . _internal . common_utils import TestCase , run_tests
2220
2321
2422# ---------------------------------------------------------------------------
@@ -84,7 +82,7 @@ def _launch_add_two(x, y, out, BLOCK_SIZE=128):
8482# Tests
8583# ---------------------------------------------------------------------------
8684
87- class TestTrtPluginsCustomOpE2E (unittest . TestCase ):
85+ class TestTrtPluginsCustomOpE2E (TestCase ):
8886 """E2E: custom_op(impl=tta.custom_plugin(...)) compiles and runs correctly."""
8987
9088 def test_add_one_single_input (self ):
@@ -94,11 +92,12 @@ class M(nn.Module):
9492 def forward (self , x ):
9593 return torch .ops .torchtrt_trt_plugins_e2e .add_one .default (x )
9694
97- trt_model , trt_out , eager_out = _compile_and_run (M (), (torch .randn (128 ),))
98- assert_trt_compiled (
99- self , trt_model , trt_out , eager_out ,
100- expected_tta_metadata = [{"backend" : "triton" }],
101- )
95+ model = M ().eval ().cuda ()
96+ inputs = [torch .randn (128 , device = "cuda" )]
97+
98+ torch ._dynamo .reset ()
99+ compiled = torch_tensorrt .compile (model , inputs = inputs , min_block_size = 1 )
100+ torch .testing .assert_close (compiled (* inputs ), model (* inputs ), rtol = 1e-3 , atol = 1e-3 )
102101
103102 def test_add_two_two_inputs (self ):
104103 """Two-input plugin: output = x + y"""
@@ -107,13 +106,12 @@ class M(nn.Module):
107106 def forward (self , x , y ):
108107 return torch .ops .torchtrt_trt_plugins_e2e .add_two .default (x , y )
109108
110- trt_model , trt_out , eager_out = _compile_and_run (
111- M (), (torch .randn (256 ), torch .randn (256 ))
112- )
113- assert_trt_compiled (
114- self , trt_model , trt_out , eager_out ,
115- expected_tta_metadata = [{"backend" : "triton" }],
116- )
109+ model = M ().eval ().cuda ()
110+ inputs = [torch .randn (256 , device = "cuda" ), torch .randn (256 , device = "cuda" )]
111+
112+ torch ._dynamo .reset ()
113+ compiled = torch_tensorrt .compile (model , inputs = inputs , min_block_size = 1 )
114+ torch .testing .assert_close (compiled (* inputs ), model (* inputs ), rtol = 1e-3 , atol = 1e-3 )
117115
118116 def test_add_one_in_larger_graph (self ):
119117 """Plugin fused inside a larger graph with aten ops."""
@@ -124,12 +122,13 @@ def forward(self, x):
124122 x = torch .ops .torchtrt_trt_plugins_e2e .add_one .default (x )
125123 return x + 0.5
126124
127- trt_model , trt_out , eager_out = _compile_and_run (M (), (torch .randn (256 ),))
128- assert_trt_compiled (
129- self , trt_model , trt_out , eager_out ,
130- expected_tta_metadata = [{"backend" : "triton" }],
131- )
125+ model = M ().eval ().cuda ()
126+ inputs = [torch .randn (256 , device = "cuda" )]
127+
128+ torch ._dynamo .reset ()
129+ compiled = torch_tensorrt .compile (model , inputs = inputs , min_block_size = 1 )
130+ torch .testing .assert_close (compiled (* inputs ), model (* inputs ), rtol = 1e-3 , atol = 1e-3 )
132131
133132
134133if __name__ == "__main__" :
135- unittest . main ()
134+ run_tests ()
0 commit comments