diff --git a/tensorrt-sys/build.rs b/tensorrt-sys/build.rs index 28d0d99..600aa3d 100644 --- a/tensorrt-sys/build.rs +++ b/tensorrt-sys/build.rs @@ -24,6 +24,13 @@ fn tensorrt_configuration() { println!("cargo:rustc-link-lib=dylib=nvinfer_plugin"); } +fn tensorrt_include_path() -> String { + match option_env!("TRT_INSTALL_DIR") { + Some(trt_include_dir) => format!("{}/include", trt_include_dir), + None => ".".to_string(), + } +} + // Not sure if I love this solution but I think it's relatively robust enough for now on Unix systems. // Still have to thoroughly test what happens with a TRT library installed that's not done by the // dpkg. It's possible that we'll just have to fall back to only supporting one system library and assuming that @@ -41,10 +48,10 @@ fn main() -> Result<(), ()> { cfg.define("TRT5", ""); let bindings = builder() .clang_args(&["-x", "c++"]) + .clang_args(&["-I", &tensorrt_include_path()[..]]) .header("trt-sys/tensorrt_api.h") .size_t_is_usize(true) .generate()?; - bindings.write_to_file("src/bindings.rs").unwrap(); } @@ -55,6 +62,7 @@ fn main() -> Result<(), ()> { let bindings = builder() .clang_arg("-DTRT6") .clang_args(&["-x", "c++"]) + .clang_args(&["-I", &tensorrt_include_path()[..]]) .header("trt-sys/tensorrt_api.h") .size_t_is_usize(true) .generate()?; @@ -69,6 +77,7 @@ fn main() -> Result<(), ()> { let bindings = builder() .clang_arg("-DTRT7") .clang_args(&["-x", "c++"]) + .clang_args(&["-I", &tensorrt_include_path()[..]]) .header("trt-sys/tensorrt_api.h") .size_t_is_usize(true) .generate()?; @@ -79,6 +88,8 @@ fn main() -> Result<(), ()> { let dst = cfg.build(); println!("cargo:rustc-link-search=native={}", dst.display()); println!("cargo:rustc-link-lib=static=trt-sys"); + + #[cfg(target_os = "linux")] println!("cargo:rustc-link-lib=dylib=stdc++"); tensorrt_configuration(); diff --git a/tensorrt-sys/trt-sys/CMakeLists.txt b/tensorrt-sys/trt-sys/CMakeLists.txt index 10785fd..2b4748e 100644 --- a/tensorrt-sys/trt-sys/CMakeLists.txt +++ b/tensorrt-sys/trt-sys/CMakeLists.txt @@ -14,7 +14,11 @@ endif() set(CMAKE_CXX_STANDARD 17) -set(CMAKE_CXX_FLAGS "-fPIC -O3 -Wall -Wextra -Werror -Wno-unknown-pragmas -Wno-deprecated -Wno-deprecated-declarations") +if(WIN32) + set(CMAKE_CXX_FLAGS "-O3") +elseif(LINUX) + set(CMAKE_CXX_FLAGS "-fPIC -O3 -Wall -Wextra -Werror -Wno-unknown-pragmas -Wno-deprecated -Wno-deprecated-declarations") +endif() file(GLOB source_files "TRTLogger/*.cpp" diff --git a/tensorrt/Cargo.toml b/tensorrt/Cargo.toml index dd010d3..0324f2c 100644 --- a/tensorrt/Cargo.toml +++ b/tensorrt/Cargo.toml @@ -19,14 +19,14 @@ trt-7 = ["tensorrt-sys/trt-7"] [dependencies] # Uncomment when working locally -#tensorrt-sys = { path = "../tensorrt-sys" } -tensorrt-sys = { git = "https://github.com/mstallmo/tensorrt-rs", branch = "develop" } -#tensorrt_rs_derive = { path = "../tensorrt_rs_derive" } -tensorrt_rs_derive = { git = "https://github.com/mstallmo/tensorrt-rs", branch = "develop" } -ndarray = "0.13" -ndarray-image = "0.2" +tensorrt-sys = { path = "../tensorrt-sys" } +#tensorrt-sys = { git = "https://github.com/mstallmo/tensorrt-rs", branch = "develop" } +tensorrt_rs_derive = { path = "../tensorrt_rs_derive" } +#tensorrt_rs_derive = { git = "https://github.com/mstallmo/tensorrt-rs", branch = "develop" } +ndarray = "0.15.3" +ndarray-image = "0.3.0" image = "0.23" -imageproc = "0.21.0" +imageproc = "0.22.0" bitflags = "1.2" num-traits = "0.2.12" num-derive = "0.3.2" diff --git a/tensorrt/examples/mnist_uff/main.rs b/tensorrt/examples/mnist_uff/main.rs index ec1b4cf..44a43fa 100644 --- a/tensorrt/examples/mnist_uff/main.rs +++ b/tensorrt/examples/mnist_uff/main.rs @@ -52,7 +52,7 @@ fn main() { let mut output = ndarray::Array1::::zeros(10); let outputs = vec![ExecuteInput::Float(&mut output)]; context - .execute(ExecuteInput::Float(&mut pre_processed), outputs) + .execute(ExecuteInput::Float(&mut pre_processed), outputs, None) .unwrap(); println!("output: {}", output); } diff --git a/tensorrt/examples/onnx/main.rs b/tensorrt/examples/onnx/main.rs index 1f3229d..732b853 100644 --- a/tensorrt/examples/onnx/main.rs +++ b/tensorrt/examples/onnx/main.rs @@ -55,7 +55,7 @@ fn main() { let mut output = ndarray::Array1::::zeros(1000); let outputs = vec![ExecuteInput::Float(&mut output)]; context - .execute(ExecuteInput::Float(&mut pre_processed), outputs) + .execute(ExecuteInput::Float(&mut pre_processed), outputs, None) .unwrap(); println!("output: {}", output); } diff --git a/tensorrt/examples/ssd_uff/main.rs b/tensorrt/examples/ssd_uff/main.rs index 261f033..cd98ffd 100644 --- a/tensorrt/examples/ssd_uff/main.rs +++ b/tensorrt/examples/ssd_uff/main.rs @@ -58,7 +58,7 @@ fn infer(engine: &Engine, input: &mut Array1) -> (ndarray::Array1, nda ExecuteInput::Integer(&mut keep_count), ]; let execute_input = ExecuteInput::Float(input); - context.execute(execute_input, outputs).unwrap(); + context.execute(execute_input, outputs, None).unwrap(); (top_detections, keep_count) } diff --git a/tensorrt/src/context.rs b/tensorrt/src/context.rs index e96bdc8..fefda54 100644 --- a/tensorrt/src/context.rs +++ b/tensorrt/src/context.rs @@ -122,6 +122,7 @@ impl Context { &self, input_data: ExecuteInput, mut output_data: Vec>, + batch_size: Option, ) -> Result<(), Error> { let mut buffers = Vec::::with_capacity(output_data.len() + 1); let dev_buffer = match input_data { @@ -146,7 +147,11 @@ impl Context { .collect::>(); unsafe { - execute(self.internal_context, bindings.as_mut_ptr(), 1); + execute( + self.internal_context, + bindings.as_mut_ptr(), + batch_size.unwrap_or(1), + ); } for (idx, output) in buffers.iter().skip(1).enumerate() { @@ -179,7 +184,7 @@ mod tests { use crate::data_size::GB; use crate::dims::DimsCHW; use crate::engine::Engine; - use crate::profiler::RustProfiler; + // use crate::profiler::RustProfiler; use crate::runtime::Logger; use crate::uff::{UffFile, UffInputOrder, UffParser}; use lazy_static::lazy_static;