Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/relax/backend/contrib/tensorrt/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ struct TensorRTCompilerConfigNode : public ffi::Object {
"TensorRT version as (major, minor, patch).",
refl::DefaultValue(ffi::Array<int64_t>({6, 0, 1})))
.def_ro("use_implicit_batch", &TensorRTCompilerConfigNode::use_implicit_batch,
"Use implicit batch", refl::DefaultValue(true))
"Use implicit batch (removed in TensorRT 10; networks are always explicit-batch)",
refl::DefaultValue(false))
.def_ro("max_workspace_size", &TensorRTCompilerConfigNode::max_workspace_size,
"Max workspace size", refl::DefaultValue(size_t(1) << 30))
.def_ro("remove_no_mac_subgraphs", &TensorRTCompilerConfigNode::remove_no_mac_subgraphs,
Expand Down
129 changes: 55 additions & 74 deletions src/runtime/extra/contrib/tensorrt/tensorrt_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,36 +40,24 @@ namespace contrib {

TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger,
const std::vector<const DLTensor*>& data_entry,
size_t max_workspace_size, bool use_implicit_batch, bool use_fp16,
int batch_size, nvinfer1::IInt8Calibrator* calibrator)
: data_entry_(data_entry),
size_t max_workspace_size, bool use_fp16,
nvinfer1::IInt8Calibrator* calibrator)
: trt_logger_(logger),
data_entry_(data_entry),
max_workspace_size_(max_workspace_size),
use_implicit_batch_(use_implicit_batch),
use_fp16_(use_fp16),
use_int8_(false),
batch_size_(batch_size),
calibrator_(calibrator) {
// Create TRT builder and network.
builder_ = nvinfer1::createInferBuilder(*logger);
builder_ = nvinfer1::createInferBuilder(*trt_logger_);

#if TRT_VERSION_GE(6, 0, 1)
// Use INetworkV2.
auto flags =
1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
if (use_implicit_batch_) {
flags = 0U;
builder_->setMaxBatchSize(batch_size_);
}
// TensorRT 10 removed implicit-batch mode and the kEXPLICIT_BATCH creation flag; every network is
// explicit-batch, so the batch dimension is simply dimension 0 of each binding and is varied
// through optimization profiles rather than IBuilder::setMaxBatchSize.
if (calibrator_ != nullptr) {
use_int8_ = true;
}
network_ = builder_->createNetworkV2(flags);
#else
builder_->setMaxBatchSize(batch_size_);
builder_->setMaxWorkspaceSize(max_workspace_size_);
builder_->setFp16Mode(use_fp16_);
network_ = builder_->createNetwork();
#endif
network_ = builder_->createNetworkV2(0U);
}

nvinfer1::DataType DLDataType2NVDataType(DLDataType data_type) {
Expand All @@ -87,10 +75,7 @@ void TensorRTBuilder::AddInput(int nid, uint32_t entry_id, const JSONGraphNode&
for (size_t i = 0; i < shapes.size(); ++i) {
const std::string name = node_name + "_" + std::to_string(i);
auto shape = shapes[i];
// Remove batch dim when not in explicit batch mode.
if (use_implicit_batch_ && shape.size() > 1) {
shape.erase(shape.begin());
}
// TensorRT 10 is always explicit-batch: keep the full shape including the batch dimension.
nvinfer1::Dims dims = VectorToTrtDims(shape);
auto input_tensor = network_->addInput(name.c_str(), DLDataType2NVDataType(dtypes[i]), dims);
node_output_map_[nid].push_back(TensorRTOpInput(input_tensor));
Expand Down Expand Up @@ -168,11 +153,10 @@ void TensorRTBuilder::AddLayer(int nid, const JSONGraphNode& node) {
}

TensorRTEngineAndContext TensorRTBuilder::BuildEngine() {
// Process graph to create INetworkDefinition.
// Build engine.
#if TRT_VERSION_GE(6, 0, 1)
// Build engine.
config_ = builder_->createBuilderConfig();
config_->setMaxWorkspaceSize(max_workspace_size_);
// TensorRT 10 replaced IBuilderConfig::setMaxWorkspaceSize with a tunable memory pool.
config_->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size_);
if (use_fp16_) {
config_->setFlag(nvinfer1::BuilderFlag::kFP16);
}
Expand All @@ -184,40 +168,48 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() {
LOG(INFO) << "config finishes setting up calibrator as INT8 mode ... ";
}

// Add profiles.
if (!use_implicit_batch_) {
auto profile = builder_->createOptimizationProfile();
for (int i = 0; i < network_->getNbInputs(); ++i) {
auto name = network_->getInput(i)->getName();
const uint32_t entry_id = entry_id_map_[name];
std::vector<int64_t> shape(data_entry_[entry_id]->shape,
data_entry_[entry_id]->shape + data_entry_[entry_id]->ndim);
auto dims = VectorToTrtDims(shape);
// Every network is explicit-batch in TRT10, so always add an optimization profile that pins each
// input to its concrete shape (with a minimum batch of 1 for dynamic batch dimensions).
auto profile = builder_->createOptimizationProfile();
for (int i = 0; i < network_->getNbInputs(); ++i) {
auto name = network_->getInput(i)->getName();
const uint32_t entry_id = entry_id_map_[name];
std::vector<int64_t> shape(data_entry_[entry_id]->shape,
data_entry_[entry_id]->shape + data_entry_[entry_id]->ndim);
auto dims = VectorToTrtDims(shape);

profile->setDimensions(name, nvinfer1::OptProfileSelector::kOPT, dims);
profile->setDimensions(name, nvinfer1::OptProfileSelector::kMAX, dims);
// Set minimum batch size to 1 when dynamic batching is used.
if (network_->getInput(i)->getDimensions().nbDims >= 1 &&
network_->getInput(i)->getDimensions().d[0] == -1) {
dims.d[0] = 1;
}
profile->setDimensions(name, nvinfer1::OptProfileSelector::kMIN, dims);
profile->setDimensions(name, nvinfer1::OptProfileSelector::kOPT, dims);
profile->setDimensions(name, nvinfer1::OptProfileSelector::kMAX, dims);
// The network inputs are built with static shapes, so the profile must match them exactly; only
// lower kMIN for a genuinely dynamic (-1) leading dimension.
if (network_->getInput(i)->getDimensions().nbDims >= 1 &&
network_->getInput(i)->getDimensions().d[0] == -1) {
dims.d[0] = 1;
}
config_->addOptimizationProfile(profile);
profile->setDimensions(name, nvinfer1::OptProfileSelector::kMIN, dims);
}
nvinfer1::ICudaEngine* engine = builder_->buildEngineWithConfig(*network_, *config_);
#else
nvinfer1::ICudaEngine* engine = builder_->buildCudaEngine(*network_);
#endif
TVM_FFI_ICHECK_EQ(engine->getNbBindings(),
network_input_names_.size() + network_output_names_.size());
config_->addOptimizationProfile(profile);

// TensorRT 10 removed buildEngineWithConfig; build a serialized engine and deserialize it through
// an IRuntime that is kept alive alongside the engine (TensorRTEngineAndContext::runtime).
nvinfer1::IHostMemory* plan = builder_->buildSerializedNetwork(*network_, *config_);
TVM_FFI_ICHECK(plan) << "Failed to build TensorRT serialized network.";
nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(*trt_logger_);
nvinfer1::ICudaEngine* engine = runtime->deserializeCudaEngine(plan->data(), plan->size());
delete plan;
if (engine == nullptr) {
delete runtime;
TVM_FFI_THROW(InternalError) << "Failed to deserialize the TensorRT engine.";
}
Comment thread
tlopex marked this conversation as resolved.
TVM_FFI_ICHECK_EQ(
engine->getNbIOTensors(),
static_cast<int32_t>(network_input_names_.size() + network_output_names_.size()));
nvinfer1::IExecutionContext* context = engine->createExecutionContext();
CleanUp();

TVM_FFI_ICHECK(engine);
TVM_FFI_ICHECK(context);

return {engine, context, network_input_names_, network_output_names_};
return {runtime, engine, context, network_input_names_, network_output_names_};
}

nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr,
Expand All @@ -236,46 +228,35 @@ nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr,
}
weight.count = count;
weight.values = new float[count];
TVM_FFI_ICHECK_EQ(TVMTensorCopyToBytes(const_cast<DLTensor*>(dptr),
const_cast<void*>(weight.values), weight_bytes),
0)
<< TVMGetLastError();
// Tensor::CopyToBytes throws on failure (the old C API TVMTensorCopyToBytes/TVMGetLastError
// were removed during the tvm-ffi refactor).
Tensor::CopyToBytes(dptr, const_cast<void*>(weight.values), weight_bytes);
trt_weights_.push_back(weight);
return weight;
}

nvinfer1::ITensor* TensorRTBuilder::GetInputAsTensor(const TensorRTOpInput& input) {
if (input.type == kTensor) return input.tensor;
auto shape = input.weight_shape;
// Remove batch dim when not in explicit batch mode.
// Example:
// x = dims (1, 32, 224, 224) which becomes TRT Dims (32, 224, 224)
// y = dims (1, 32)
// z = add(x, y)
// y needs to have TRT dims (32,), otherwise broadcasting will result in z having
// TRT Dims(1, 32, 224, 224) when it should be (32, 224, 224).
if (use_implicit_batch_ && shape.size() > 1 && shape[0] == 1) {
shape.erase(shape.begin());
}
// TensorRT 10 is always explicit-batch, so the constant keeps its full shape.
return network_->addConstant(VectorToTrtDims(shape), input.weight)->getOutput(0);
}

void TensorRTBuilder::CleanUp() {
// TensorRT 10 removed obj->destroy(); objects are released with the delete operator.
VLOG(1) << "Destroying TensorRT network";
TVM_FFI_ICHECK(network_);
network_->destroy();
delete network_;
network_ = nullptr;

#if TRT_VERSION_GE(6, 0, 1)
VLOG(1) << "Destroying TensorRT config";
TVM_FFI_ICHECK(config_);
config_->destroy();
delete config_;
config_ = nullptr;
#endif

VLOG(1) << "Destroying TensorRT builder";
TVM_FFI_ICHECK(builder_);
builder_->destroy();
delete builder_;
builder_ = nullptr;

VLOG(1) << "Destroying TensorRT weights";
Expand Down
18 changes: 7 additions & 11 deletions src/runtime/extra/contrib/tensorrt/tensorrt_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
* perform inference.
*/
struct TensorRTEngineAndContext {
// TensorRT 10 builds a serialized engine which is then deserialized through an IRuntime. The
// runtime must outlive the engine it produced, so it is owned alongside the engine/context.
nvinfer1::IRuntime* runtime = nullptr;
nvinfer1::ICudaEngine* engine = nullptr;
nvinfer1::IExecutionContext* context = nullptr;
std::vector<std::string> inputs;
Expand All @@ -67,12 +70,10 @@ class TensorRTBuilder {
* \brief Create TensorRT builder.
* \param logger TensorRT logger to use for errors and warnings.
* \param max_workspace_size Workspace size parameter for TensorRT engine build phase.
* \param use_implicit_batch Whether to use implicit batch mode (default)
* \param use_fp16 Whether to automatically convert a model to fp16
* \param batch_size If use_implicit_batch,
*/
TensorRTBuilder(TensorRTLogger* logger, const std::vector<const DLTensor*>& data_entry,
size_t max_workspace_size, bool use_implicit_batch, bool use_fp16, int batch_size,
size_t max_workspace_size, bool use_fp16,
nvinfer1::IInt8Calibrator* calibrator = nullptr);

/*!
Expand Down Expand Up @@ -124,13 +125,14 @@ class TensorRTBuilder {
/*! \brief Maps a node to its outputs. */
std::unordered_map<int, std::vector<TensorRTOpInput>> node_output_map_;

/*! \brief TensorRT logger, used to create the builder and the deserialization runtime. */
TensorRTLogger* trt_logger_ = nullptr;

/*! \brief TensorRT builder. */
nvinfer1::IBuilder* builder_ = nullptr;

#if TRT_VERSION_GE(6, 0, 1)
/*! \brief TensorRT builder config. */
nvinfer1::IBuilderConfig* config_ = nullptr;
#endif

/*! \brief TensorRT network definition. */
nvinfer1::INetworkDefinition* network_ = nullptr;
Expand All @@ -147,18 +149,12 @@ class TensorRTBuilder {
/*! \brief Max workspace size in bytes for TRT. */
size_t max_workspace_size_;

/*! \brief Whether to use implicit batch mode. */
bool use_implicit_batch_;

/*! \brief Whether to automatically convert model to 16-bit floating point precision. */
bool use_fp16_;

/*! \brief whether to automatically convert model to int8 precision */
bool use_int8_;

/*! \brief Batch size to optimize for. */
int batch_size_;

/*! \brief Input names. */
std::vector<std::string> network_input_names_;

Expand Down
5 changes: 4 additions & 1 deletion src/runtime/extra/contrib/tensorrt/tensorrt_calibrator.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,10 @@ class TensorRTCalibrator : public nvinfer1::IInt8EntropyCalibrator2 {
const int num_inputs = data_sizes_[0].size();
buffers_.assign(num_inputs, nullptr);
for (int i = 0; i < num_inputs; ++i) {
TVM_FFI_CHECK_CUDA_ERROR(cudaMalloc(&buffers_[i], data_sizes_[0][i] * sizeof(float)));
// data_sizes_ holds the per-sample element count; getBatch() copies a full batch
// (batch_size_ * per-sample) into each buffer, so the device buffer must be sized to match.
TVM_FFI_CHECK_CUDA_ERROR(
cudaMalloc(&buffers_[i], batch_size_ * data_sizes_[0][i] * sizeof(float)));
}
}
};
Expand Down
Loading
Loading