Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/tvm/target/detect_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def _detect_cuda(dev: Device) -> Target:
"max_threads_per_block": dev.max_threads_per_block,
"thread_warp_size": dev.warp_size,
"arch": "sm_" + dev.compute_version.replace(".", ""),
"enable_fast_math": False,
}
)

Expand Down
5 changes: 4 additions & 1 deletion python/tvm/target/tag_registry/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ def _register_cuda_tag(name, arch, shared_mem=49152, regs=65536, **extra):
"max_threads_per_block": 1024,
"thread_warp_size": 32,
"registers_per_block": regs,
# Default to disable fast math
"enable_fast_math": False,
}
config.update(extra)
register_tag(name, config)


def _register_jetson_tag(name, arch, mcpu, num_cores, regs=65536):
def _register_jetson_tag(name, arch, mcpu, num_cores, regs=65536, enable_fast_math=False):
register_tag(
name,
{
Expand All @@ -49,6 +51,7 @@ def _register_jetson_tag(name, arch, mcpu, num_cores, regs=65536):
"mcpu": mcpu,
"num-cores": num_cores,
},
"enable_fast_math": enable_fast_math,
},
)

Expand Down
30 changes: 22 additions & 8 deletions src/target/cuda/intrin_rule_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,37 +174,46 @@ TVM_REGISTER_OP("tirx.nearbyint")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);

TVM_REGISTER_OP("tirx.exp")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAFastMath>);
.set_attr<FLowerIntrinsic>("cuda.fastmath.FLowerIntrinsic", DispatchPureExtern<CUDAFastMath>)
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);

TVM_REGISTER_OP("tirx.exp2")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);

TVM_REGISTER_OP("tirx.exp10")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAFastMath>);
.set_attr<FLowerIntrinsic>("cuda.fastmath.FLowerIntrinsic", DispatchPureExtern<CUDAFastMath>)
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);

TVM_REGISTER_OP("tirx.erf")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);

TVM_REGISTER_OP("tirx.log")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAFastMath>);
.set_attr<FLowerIntrinsic>("cuda.fastmath.FLowerIntrinsic", DispatchPureExtern<CUDAFastMath>)
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);

TVM_REGISTER_OP("tirx.log2")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAFastMath>);
.set_attr<FLowerIntrinsic>("cuda.fastmath.FLowerIntrinsic", DispatchPureExtern<CUDAFastMath>)
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);

TVM_REGISTER_OP("tirx.log10")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAFastMath>);
.set_attr<FLowerIntrinsic>("cuda.fastmath.FLowerIntrinsic", DispatchPureExtern<CUDAFastMath>)
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);

TVM_REGISTER_OP("tirx.tan")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAFastMathTan>);
// Now the fast math version of tan and the default version of tan are same.
.set_attr<FLowerIntrinsic>("cuda.fastmath.FLowerIntrinsic", DispatchPureExtern<CUDAFastMathTan>)
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);

TVM_REGISTER_OP("tirx.cos")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAFastMath>);
.set_attr<FLowerIntrinsic>("cuda.fastmath.FLowerIntrinsic", DispatchPureExtern<CUDAFastMath>)
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);

TVM_REGISTER_OP("tirx.cosh")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);

TVM_REGISTER_OP("tirx.sin")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAFastMath>);
.set_attr<FLowerIntrinsic>("cuda.fastmath.FLowerIntrinsic", DispatchPureExtern<CUDAFastMath>)
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);

TVM_REGISTER_OP("tirx.sinh")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);
Expand All @@ -213,12 +222,17 @@ TVM_REGISTER_OP("tirx.atan")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);

TVM_REGISTER_OP("tirx.tanh")
.set_attr<FLowerIntrinsic>("cuda.fastmath.FLowerIntrinsic", DispatchPureExtern<CUDAFastMath>)
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);

TVM_REGISTER_OP("tirx.sqrt")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);

TVM_REGISTER_OP("tirx.rsqrt")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);

TVM_REGISTER_OP("tirx.pow")
.set_attr<FLowerIntrinsic>("cuda.fastmath.FLowerIntrinsic", DispatchPureExtern<CUDAFastMath>)
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);

TVM_REGISTER_OP("tirx.popcount")
Expand Down
9 changes: 9 additions & 0 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,14 @@ ffi::Map<ffi::String, ffi::Any> UpdateCUDAAttrs(ffi::Map<ffi::String, ffi::Any>
}
target.Set("arch", ffi::String("sm_") + std::to_string(archInt));
}
// Update enable_fast_math
if (target.count("enable_fast_math")) {
// If enable_fast_math has been specified, validate that enable_fast_math is a bool
Downcast<bool>(target.at("enable_fast_math"));
} else {
// If enable_fast_math has not been specified, default to false
target.Set("enable_fast_math", false);
}
return target;
}

Expand Down Expand Up @@ -367,6 +375,7 @@ TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA)
.add_attr_option<int64_t>("l2_cache_size_bytes")
.add_attr_option<int64_t>("max_num_threads",
refl::DefaultValue(1024)) // TODO(@zxybazh): deprecate it
.add_attr_option<bool>("enable_fast_math")
.set_default_keys({"cuda", "gpu"})
.set_target_canonicalizer(UpdateCUDAAttrs);

Expand Down
20 changes: 14 additions & 6 deletions src/tirx/transform/lower_intrin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,21 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
using IRMutatorWithAnalyzer::VisitStmt_;
using FLowerGeneral = ffi::TypedFunction<PrimExpr(PrimExpr)>;

IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string mtriple = "")
: IRMutatorWithAnalyzer(analyzer) {
IntrinInjecter(arith::Analyzer* analyzer, const Target& tgt) : IRMutatorWithAnalyzer(analyzer) {
std::string target = tgt->kind->name;
ffi::String mtriple = tgt->GetAttr<ffi::String>("mtriple").value_or("");

std::vector<std::string> patterns;
// For CUDA targets, we need to add the fast math patterns if enable_fast_math is true.
// The priority of the fast math patterns is higher than the normal patterns.
bool is_fast_math = tgt->GetAttr<bool>("enable_fast_math").value_or(false);
if (is_fast_math) {
patterns.push_back(target + ".fastmath.FLowerIntrinsic");
patterns.push_back(target + ".fastmath.FLegalize");
}
patterns.push_back(target + ".FLowerIntrinsic");
patterns.push_back(target + ".FLegalize");

bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos);
if (is_llvm_aarch64) {
patterns.push_back(target + ".aarch64.FLowerIntrinsic");
Expand Down Expand Up @@ -354,7 +364,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {

Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) {
arith::Analyzer analyzer;
return IntrinInjecter(&analyzer, target)(std::move(stmt));
return IntrinInjecter(&analyzer, Target(ffi::String(target)))(std::move(stmt));
}

namespace transform {
Expand All @@ -365,9 +375,7 @@ Pass LowerIntrin() {
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
TVM_FFI_ICHECK(target.defined()) << "LowerIntrin: Require the target attribute";
arith::Analyzer analyzer;
auto mtriple = target.value()->GetAttr<ffi::String>("mtriple", "");
n->body =
IntrinInjecter(&analyzer, target.value()->kind->name, mtriple.value())(std::move(n->body));
n->body = IntrinInjecter(&analyzer, target.value())(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tirx.LowerIntrin", {});
Expand Down
Loading
Loading