Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion tensorrt-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ fn tensorrt_configuration() {
println!("cargo:rustc-link-lib=dylib=nvinfer_plugin");
}

fn tensorrt_include_path() -> String {
match option_env!("TRT_INSTALL_DIR") {
Some(trt_include_dir) => format!("{}/include", trt_include_dir),
None => ".".to_string(),
}
}

// Not sure if I love this solution but I think it's relatively robust enough for now on Unix systems.
// Still have to thoroughly test what happens with a TRT library installed that's not done by the
// dpkg. It's possible that we'll just have to fall back to only supporting one system library and assuming that
Expand All @@ -41,10 +48,10 @@ fn main() -> Result<(), ()> {
cfg.define("TRT5", "");
let bindings = builder()
.clang_args(&["-x", "c++"])
.clang_args(&["-I", &tensorrt_include_path()[..]])
.header("trt-sys/tensorrt_api.h")
.size_t_is_usize(true)
.generate()?;

bindings.write_to_file("src/bindings.rs").unwrap();
}

Expand All @@ -55,6 +62,7 @@ fn main() -> Result<(), ()> {
let bindings = builder()
.clang_arg("-DTRT6")
.clang_args(&["-x", "c++"])
.clang_args(&["-I", &tensorrt_include_path()[..]])
.header("trt-sys/tensorrt_api.h")
.size_t_is_usize(true)
.generate()?;
Expand All @@ -69,6 +77,7 @@ fn main() -> Result<(), ()> {
let bindings = builder()
.clang_arg("-DTRT7")
.clang_args(&["-x", "c++"])
.clang_args(&["-I", &tensorrt_include_path()[..]])
.header("trt-sys/tensorrt_api.h")
.size_t_is_usize(true)
.generate()?;
Expand All @@ -79,6 +88,8 @@ fn main() -> Result<(), ()> {
let dst = cfg.build();
println!("cargo:rustc-link-search=native={}", dst.display());
println!("cargo:rustc-link-lib=static=trt-sys");

#[cfg(target_os = "linux")]
println!("cargo:rustc-link-lib=dylib=stdc++");

tensorrt_configuration();
Expand Down
6 changes: 5 additions & 1 deletion tensorrt-sys/trt-sys/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ endif()

set(CMAKE_CXX_STANDARD 17)

set(CMAKE_CXX_FLAGS "-fPIC -O3 -Wall -Wextra -Werror -Wno-unknown-pragmas -Wno-deprecated -Wno-deprecated-declarations")
if(WIN32)
set(CMAKE_CXX_FLAGS "-O3")
elseif(LINUX)
set(CMAKE_CXX_FLAGS "-fPIC -O3 -Wall -Wextra -Werror -Wno-unknown-pragmas -Wno-deprecated -Wno-deprecated-declarations")
endif()

file(GLOB source_files
"TRTLogger/*.cpp"
Expand Down
14 changes: 7 additions & 7 deletions tensorrt/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ trt-7 = ["tensorrt-sys/trt-7"]

[dependencies]
# Uncomment when working locally
#tensorrt-sys = { path = "../tensorrt-sys" }
tensorrt-sys = { git = "https://github.com/mstallmo/tensorrt-rs", branch = "develop" }
#tensorrt_rs_derive = { path = "../tensorrt_rs_derive" }
tensorrt_rs_derive = { git = "https://github.com/mstallmo/tensorrt-rs", branch = "develop" }
ndarray = "0.13"
ndarray-image = "0.2"
tensorrt-sys = { path = "../tensorrt-sys" }
#tensorrt-sys = { git = "https://github.com/mstallmo/tensorrt-rs", branch = "develop" }
tensorrt_rs_derive = { path = "../tensorrt_rs_derive" }
#tensorrt_rs_derive = { git = "https://github.com/mstallmo/tensorrt-rs", branch = "develop" }
ndarray = "0.15.3"
ndarray-image = "0.3.0"
image = "0.23"
imageproc = "0.21.0"
imageproc = "0.22.0"
bitflags = "1.2"
num-traits = "0.2.12"
num-derive = "0.3.2"
Expand Down
2 changes: 1 addition & 1 deletion tensorrt/examples/mnist_uff/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ fn main() {
let mut output = ndarray::Array1::<f32>::zeros(10);
let outputs = vec![ExecuteInput::Float(&mut output)];
context
.execute(ExecuteInput::Float(&mut pre_processed), outputs)
.execute(ExecuteInput::Float(&mut pre_processed), outputs, None)
.unwrap();
println!("output: {}", output);
}
2 changes: 1 addition & 1 deletion tensorrt/examples/onnx/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ fn main() {
let mut output = ndarray::Array1::<f32>::zeros(1000);
let outputs = vec![ExecuteInput::Float(&mut output)];
context
.execute(ExecuteInput::Float(&mut pre_processed), outputs)
.execute(ExecuteInput::Float(&mut pre_processed), outputs, None)
.unwrap();
println!("output: {}", output);
}
2 changes: 1 addition & 1 deletion tensorrt/examples/ssd_uff/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ fn infer(engine: &Engine, input: &mut Array1<f32>) -> (ndarray::Array1<f32>, nda
ExecuteInput::Integer(&mut keep_count),
];
let execute_input = ExecuteInput::Float(input);
context.execute(execute_input, outputs).unwrap();
context.execute(execute_input, outputs, None).unwrap();

(top_detections, keep_count)
}
Expand Down
9 changes: 7 additions & 2 deletions tensorrt/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ impl Context {
&self,
input_data: ExecuteInput<D1>,
mut output_data: Vec<ExecuteInput<D2>>,
batch_size: Option<i32>,
) -> Result<(), Error> {
let mut buffers = Vec::<DeviceBuffer>::with_capacity(output_data.len() + 1);
let dev_buffer = match input_data {
Expand All @@ -146,7 +147,11 @@ impl Context {
.collect::<Vec<*mut c_void>>();

unsafe {
execute(self.internal_context, bindings.as_mut_ptr(), 1);
execute(
self.internal_context,
bindings.as_mut_ptr(),
batch_size.unwrap_or(1),
);
}

for (idx, output) in buffers.iter().skip(1).enumerate() {
Expand Down Expand Up @@ -179,7 +184,7 @@ mod tests {
use crate::data_size::GB;
use crate::dims::DimsCHW;
use crate::engine::Engine;
use crate::profiler::RustProfiler;
// use crate::profiler::RustProfiler;
use crate::runtime::Logger;
use crate::uff::{UffFile, UffInputOrder, UffParser};
use lazy_static::lazy_static;
Expand Down