Commit de1e14b
test: add dtype_dispatch tests
1 parent 8604d2d commit de1e14b

6 files changed: 544 additions & 58 deletions


CMakeLists.txt

Lines changed: 35 additions & 0 deletions

@@ -204,3 +204,38 @@ link_infini_train_exe(test_precision_check)
 add_executable(test_lora test/lora/test_lora.cc)
 link_infini_train_exe(test_lora)
 
+add_executable(test_scalar test/scalar/test_scalar.cc)
+link_infini_train_exe(test_scalar)
+
+add_executable(test_dtype_dispatch test/dispatch/test_dtype_dispatch.cc)
+link_infini_train_exe(test_dtype_dispatch)
+
+# Negative compile test: missing dtype registration must fail at compile time.
+set(DTYPE_DISPATCH_COMPILE_FAIL_SOURCE
+    ${PROJECT_SOURCE_DIR}/test/dispatch/test_dtype_dispatch_compile_fail.cc)
+
+try_compile(DTYPE_DISPATCH_COMPILE_UNEXPECTEDLY_SUCCEEDED
+    ${CMAKE_BINARY_DIR}/CMakeFiles/try_compile_dtype_dispatch_missing_map
+    SOURCES ${DTYPE_DISPATCH_COMPILE_FAIL_SOURCE}
+    CMAKE_FLAGS
+        "-DCMAKE_CXX_STANDARD=${CMAKE_CXX_STANDARD}"
+        "-DCMAKE_CXX_STANDARD_REQUIRED=ON"
+        "-DCMAKE_CXX_EXTENSIONS=OFF"
+        "-DCMAKE_CXX_FLAGS=-I${PROJECT_SOURCE_DIR}"
+    OUTPUT_VARIABLE DTYPE_DISPATCH_TRY_COMPILE_OUTPUT
+)
+
+if(DTYPE_DISPATCH_COMPILE_UNEXPECTEDLY_SUCCEEDED)
+    message(FATAL_ERROR
+        "dtype dispatch compile-fail test unexpectedly succeeded.\n"
+        "Source: ${DTYPE_DISPATCH_COMPILE_FAIL_SOURCE}\n"
+        "Output:\n${DTYPE_DISPATCH_TRY_COMPILE_OUTPUT}")
+endif()
+
+add_custom_target(test_dtype_dispatch_compile_fail
+    COMMAND ${CMAKE_COMMAND} -E echo
+        "dtype dispatch compile-fail check passed (missing dtype registration correctly fails to compile)."
+    VERBATIM
+)
+
+add_dependencies(test_dtype_dispatch test_dtype_dispatch_compile_fail)

docs/dtype_registry_design.md

Lines changed: 12 additions & 58 deletions

@@ -1,31 +1,23 @@
 # Low-Precision DType Abstraction & Backend Registration Design
 Unified low-precision type abstraction and explicit backend registration. PR: https://github.com/InfiniTensor/InfiniTrain/pull/114
 
-## 1. Background and Motivation
+## 1. Background and Goals
 
-Before BF16 / FP16 were introduced, InfiniTrain had no unified low-precision type abstraction at the framework level; all 16-bit floating-point semantics were bound directly to the CUDA native types `__half` / `__nv_bfloat16`. This
-caused several problems:
+Before BF16 / FP16 were introduced, InfiniTrain had no unified low-precision type abstraction at the framework level; all 16-bit floating-point semantics were bound directly to backend-native types: `__half` / `__nv_bfloat16` on the CUDA side, and plain `uint16_t` on the CPU side. This design caused several problems:
 
 1. **Framework code was polluted by `#ifdef USE_CUDA`.**
-   Generic modules such as `infini_train/include/datatype.h` and `infini_train/src/nn/init.cc` had to
-   write `#ifdef USE_CUDA … #else …` to switch the 16-bit type mapping between the
-   "with CUDA" and "without CUDA" builds; the non-CUDA path could only degrade to `uint16_t`, which then
+   Generic modules such as `infini_train/include/datatype.h` and `infini_train/src/nn/init.cc` had to write `#ifdef USE_CUDA … #else …` to switch the 16-bit type mapping between the "with CUDA" and "without CUDA" builds; the non-CUDA path could only degrade to `uint16_t`, which then
    became ambiguous with the reverse mapping for `kUINT16`.
 2. **`TypeMap<DType>` was a single table shared by all backends.**
-   `TypeMap` mapped every scalar type directly to a C++ type. CPU and CUDA shared one table,
-   so `kFLOAT16` could not map to different native scalars on different backends; supporting new hardware meant editing framework headers.
+   `TypeMap` mapped every scalar type directly to a C++ type. CPU and CUDA shared one table, so `kFLOAT16` could not map to different native scalars on different backends; supporting new hardware meant editing framework headers.
 3. **Type promotion was coupled to concrete backend types.**
-   The old `WidestType_t<T1, T2>` performed promotion at the C++ template level: every call site first had to dispatch
-   to a pair of concrete scalar types (e.g. `nv_bfloat16` + `float`) and then hand them to the metafunction. This tied
-   type promotion, a purely dtype-level concern, to concrete backend scalars.
+   The old `WidestType_t<T1, T2>` performed promotion at the C++ template level: every call site first had to dispatch to a pair of concrete scalar types (e.g. `nv_bfloat16` + `float`) and then hand them to the metafunction. This tied type promotion, a purely dtype-level concern, to concrete backend scalars.
 4. **Silent fallbacks masked errors.**
-   If a backend forgot to register BF16/FP16, the old implementation silently fell through to the `uint16_t` path, producing a
-   semantically wrong kernel instead of an explicit error.
+   If a backend forgot to define the low-precision types, the old implementation silently mapped them to `uint16_t`, producing a semantically wrong kernel instead of an explicit error.
 
 The goal of this work:
 
-> **Abstract FP16/BF16 into framework-level types** so framework code no longer touches any backend-native
-> 16-bit types; at the same time, make the backend dtype → native scalar mapping an **explicit registration** mechanism, so unregistered types are intercepted at compile time.
+> **Abstract FP16/BF16 into framework-level low-precision types** so framework code no longer depends on any backend-native 16-bit type; at the same time, make the [DataType -> backend C++ type] mapping an **explicit registration** mechanism, so that instantiating an unregistered type is intercepted with a compile-time error.
 
 ## 2. Design In One Diagram
 
@@ -46,8 +38,8 @@ kernel code ──► DispatchCpuFunc / DispatchCudaFunc / DispatchXxxFunc
 
 Key points:
 
-- The framework layer provides no "DataType → C++ type" mapping path; all concrete type bindings happen in the backends via `BackendTypeMap<Dev, DType>`.
-- The `BackendTypeMap<Dev, DType>` primary template is **declared but never defined**: only dtype combinations that a backend explicitly specializes and registers may participate in kernel dispatch; unregistered combinations are intercepted at compile time by a `static_assert` during template instantiation.
+- The framework layer provides no "DataType → backend C++ type" mapping path; all concrete type bindings happen in the backends via `BackendTypeMap<Dev, DType>`.
+- The `BackendTypeMap<Dev, DType>` primary template is **declared but never defined**: only combinations that a backend explicitly specializes and registers may participate in kernel dispatch; unregistered combinations are intercepted at compile time by a `static_assert` during template instantiation.
 
 ## 3. Core API
 
@@ -59,51 +51,13 @@ kernel code ──► DispatchCpuFunc / DispatchCudaFunc / DispatchXxxFunc
 | `INFINI_REGISTER_STANDARD_BACKEND_TYPES(DEV)` | [core/backend_type_map.h](../infini_train/include/core/backend_type_map.h) | Registers the 10 non-low-precision dtypes (`kUINT8…kFLOAT64`) to their C++ scalars in one shot. |
 | `DispatchCpuFunc / DispatchCudaFunc<AllowedDTypes...>` | `src/core/runtime/{cpu,cuda}/{cpu,cuda}_dispatch.h` | Backend dispatch entry points; forward to `DispatchByTypeMap<TypeMap, AllowedDTypes...>`. |
 
-## 4. Scalar: the framework-level scalar carrier
-
-`BackendTypeMap` covers "DataType → backend C++ type", but framework APIs also need a
-**DataType-agnostic** way to accept scalar arguments: the target tensor's DataType is only known at runtime, so an API cannot
-offer an overload for every numeric type, let alone expose backend-native types to callers.
-
-Hence `Scalar` ([scalar.h](../infini_train/include/scalar.h)):
-
-- Fixed storage: `double / int64_t / uint64_t` plus a `Kind` tag (`kBool / kDouble / kInt64 / kUInt64`).
-- Implicit constructors cover all framework scalars: integers go to `kInt64 / kUInt64` by signedness, all floating-point types (including `FP16 / BF16`) normalize to `kDouble`, and `bool` stands alone.
-- The single exit point `Scalar::to<T>()` converts the stored value to the backend scalar type selected by dispatch, via `common::cpu::Cast<T>`.
-
-Boundaries with the other abstractions: `BackendTypeMap` handles "DataType → backend C++ type", `PromoteDataTypes`
-handles "DataType → DataType", and `Scalar` handles "value → backend C++ type"; the three are orthogonal, and `Scalar` itself takes no part in type-promotion decisions.
-
-### 4.1 Usage pattern
-
-`Tensor::Fill(Scalar)` is the first landing point for this abstraction. The kernel-side pattern:
-
-```cpp
-// kernels/cpu/fill.cc
-void Fill(std::shared_ptr<Tensor> tensor, Scalar scalar) {
-    core::cpu::DispatchCpuFunc<INFINI_ALL_TYPES>(
-        tensor->Dtype(),
-        [=]<typename T>() {
-            auto data = reinterpret_cast<T *>(tensor->DataPtr());
-            const T v = scalar.to<T>(); // Scalar performs the "value → backend C++ type" mapping here
-            std::fill(data, data + tensor->NumElements(), v);
-        },
-        "CPU Fill");
-}
-```
-
-`DispatchCpuFunc` resolves `DataType` to `T` via `BackendTypeMap`; `Scalar::to<T>()`
-converts the caller-supplied value to that `T`.
-
-## 5. How To Add A New Backend
+## 4. How To Add A New Backend
 
 Follow this checklist; you do **not** need to modify any framework header under `infini_train/include/`, and no `#ifdef` is needed:
 
 1. In the backend's `*_dispatch.h`, include `core/backend_type_map.h` and `dtype_dispatch.h`.
 2. Call `INFINI_REGISTER_STANDARD_BACKEND_TYPES(Device::DeviceType::kXxx)` to register the 10 standard dtypes.
-3. If the hardware supports low precision, explicitly specialize `BackendTypeMap<kXxx, kFLOAT16>` / `BackendTypeMap<kXxx, kBFLOAT16>` to
-   the backend's native 16-bit scalar types; otherwise simply skip this step; dispatching an unregistered dtype triggers a
-   `static_assert` at compile time.
+3. If the hardware supports low precision, explicitly specialize `BackendTypeMap<kXxx, kFLOAT16>` / `BackendTypeMap<kXxx, kBFLOAT16>` to the backend's native 16-bit scalar types; otherwise simply skip this step: any caller that dispatches an unregistered dtype triggers a `static_assert` at compile time.
 4. Define `XxxTypeMap<DType>` forwarding/inheriting to `BackendTypeMap<kXxx, DType>`.
 5. Provide a `DispatchXxxFunc` entry point forwarding to `DispatchByTypeMap<XxxTypeMap, AllowedDTypes...>`.
 
@@ -134,7 +88,7 @@
 } // namespace infini_train::core::xxx
 ```
 
-## 6. Failure Modes
+## 5. Failure Modes
 
 | Scenario | Behavior |
 | --- | --- |

infini_train/include/scalar.h

Lines changed: 25 additions & 0 deletions

@@ -31,6 +31,31 @@ struct Scalar {
     Scalar(FP16 v) : kind(Kind::kDouble), d(static_cast<float>(v)) {}
     Scalar(BF16 v) : kind(Kind::kDouble), d(static_cast<float>(v)) {}
 
+    // TODO(dcj): Scalar::to<T>() should remain a framework-level conversion API
+    // and should not directly target backend-native types such as __nv_bfloat16
+    // or __half.
+    //
+    // Today to<T>() delegates to common::cpu::Cast, which only has explicit
+    // semantics for framework scalar types (e.g. FP16/BF16). When T is a
+    // backend-native half type, it falls back to raw static_cast, which happens
+    // to compile on CUDA (via implicit constructors) but is backend-dependent
+    // and may fail on other platforms (e.g. MACA).
+    //
+    // More importantly, this creates inconsistent rounding paths:
+    //   - to<BF16>():          double -> float -> bf16
+    //   - to<__nv_bfloat16>(): double -> bf16
+    // The two paths may yield different results due to double rounding.
+    // See `test/dtype/test_scalar.cc` (`TestToHalfPrecisionConversions`) for
+    // a similar example.
+    //
+    // Planned fix:
+    //   1) keep Scalar::to<T>() restricted to framework/common scalar types
+    //   2) introduce a standalone convert<To, From> utility for common
+    //      conversion semantics
+    //   3) let kernel/backend code use a backend-specific scalar_cast<T>
+    //      helper for native types, routing half-precision conversions
+    //      through float to guarantee consistent two-step rounding on all
+    //      backends.
     template <typename T> T to() const {
         switch (kind) {
         case Kind::kBool:

test/dtype/test_dtype_dispatch.cc

Lines changed: 120 additions & 0 deletions (new file)

@@ -0,0 +1,120 @@
#include <cstdlib>
#include <iostream>
#include <string>
#include <type_traits>

#include "glog/logging.h"

#include "infini_train/include/datatype.h"
#include "infini_train/include/dtype_dispatch.h"

#include "infini_train/src/core/runtime/cpu/cpu_dispatch.h"

using namespace infini_train;

// ============================================================================
// Test 1: HasMappedType_v intercepts backends missing FP16 / BF16
// ============================================================================

// A backend TypeMap that only registers kFLOAT32 — FP16 / BF16 are absent.
template <DataType DType> struct LowPrecisionAbsentTypeMap;

template <> struct LowPrecisionAbsentTypeMap<DataType::kFLOAT32> {
    using type = float;
};

static_assert(HasMappedType_v<LowPrecisionAbsentTypeMap, DataType::kFLOAT32>,
              "sanity: registered dtype must be detected as present");
static_assert(!HasMappedType_v<LowPrecisionAbsentTypeMap, DataType::kFLOAT16>,
              "unregistered kFLOAT16 must be intercepted by HasMappedType_v");
static_assert(!HasMappedType_v<LowPrecisionAbsentTypeMap, DataType::kBFLOAT16>,
              "unregistered kBFLOAT16 must be intercepted by HasMappedType_v");

// ============================================================================
// Test 2: CpuTypeMap resolves FP16 / BF16 to framework scalar types
// ============================================================================

static_assert(std::is_same_v<MappedType_t<core::cpu::CpuTypeMap, DataType::kFLOAT16>, FP16>,
              "CpuTypeMap<kFLOAT16> must resolve to framework FP16");
static_assert(std::is_same_v<MappedType_t<core::cpu::CpuTypeMap, DataType::kBFLOAT16>, BF16>,
              "CpuTypeMap<kBFLOAT16> must resolve to framework BF16");

// ============================================================================
// Test 3: Runtime dispatch of kFLOAT16 / kBFLOAT16
// ============================================================================

void TestRuntimeDispatchLowPrecision() {
    std::cout << "\n=== Test 3: Runtime dispatch of kFLOAT16 / kBFLOAT16 ===" << std::endl;

    // kFLOAT16 must dispatch to framework FP16
    bool called_fp16 = false;
    core::cpu::DispatchCpuFunc<DataType::kFLOAT16, DataType::kBFLOAT16>(
        DataType::kFLOAT16,
        [&called_fp16]<typename T>() {
            if constexpr (std::is_same_v<T, FP16>) {
                called_fp16 = true;
            }
        },
        "dispatch kFLOAT16");
    CHECK(called_fp16) << "DispatchCpuFunc did not invoke functor for kFLOAT16";

    // kBFLOAT16 must dispatch to framework BF16
    bool called_bf16 = false;
    core::cpu::DispatchCpuFunc<DataType::kFLOAT16, DataType::kBFLOAT16>(
        DataType::kBFLOAT16,
        [&called_bf16]<typename T>() {
            if constexpr (std::is_same_v<T, BF16>) {
                called_bf16 = true;
            }
        },
        "dispatch kBFLOAT16");
    CHECK(called_bf16) << "DispatchCpuFunc did not invoke functor for kBFLOAT16";

    std::cout << "Low-precision dispatch OK." << std::endl;
}

// ============================================================================
// Test 4: Runtime dispatch of a low-precision dtype outside AllowedDTypes
// must fatal
// ============================================================================

// Sub-process entry: tries to dispatch kFLOAT16 with only kFLOAT32 allowed.
void TriggerRuntimeUnsupportedLowPrecisionFatal() {
    core::cpu::DispatchCpuFunc<DataType::kFLOAT32>(
        DataType::kFLOAT16,
        []<typename T>() { (void)sizeof(T); },
        "intercept kFLOAT16 when only kFLOAT32 is allowed");
}

void TestRuntimeInterceptLowPrecision(const char *argv0) {
    std::cout << "\n=== Test 4: Runtime intercept of kFLOAT16 outside AllowedDTypes ===" << std::endl;
    const std::string cmd = std::string(argv0) + " --expect-runtime-fatal > /dev/null 2>&1";
    const int status = std::system(cmd.c_str());
    CHECK_NE(status, 0) << "Expected non-zero exit when dispatching an unallowed low-precision dtype";
    std::cout << "Low-precision runtime intercept OK." << std::endl;
}

// ============================================================================
// Main
// ============================================================================

int main(int argc, char *argv[]) {
    google::InitGoogleLogging(argv[0]);

    if (argc > 1 && std::string(argv[1]) == "--expect-runtime-fatal") {
        TriggerRuntimeUnsupportedLowPrecisionFatal();
        return 0;
    }

    std::cout << "========================================" << std::endl;
    std::cout << " Low-precision Dtype Dispatch Test Suite" << std::endl;
    std::cout << "========================================" << std::endl;

    std::cout << "Compile-time checks: PASSED" << std::endl;

    TestRuntimeDispatchLowPrecision();
    TestRuntimeInterceptLowPrecision(argv[0]);

    std::cout << "\nAll low-precision dtype dispatch tests passed." << std::endl;
    return 0;
}
test/dispatch/test_dtype_dispatch_compile_fail.cc

Lines changed: 28 additions & 0 deletions (new file)

@@ -0,0 +1,28 @@
#include "infini_train/include/datatype.h"
#include "infini_train/include/dtype_dispatch.h"

using namespace infini_train;

// ============================================================================
// Compile-fail: dispatching an unregistered low-precision dtype must be
// intercepted at compile time
// ============================================================================

// Models a backend that has registered standard floating types but has NOT
// yet provided a mapping for the low-precision dtypes FP16 / BF16.
template <DataType DType> struct LowPrecisionMissingTypeMap;

template <> struct LowPrecisionMissingTypeMap<DataType::kFLOAT32> {
    using type = float;
};

int main() {
    // Dispatching kFLOAT16 through LowPrecisionMissingTypeMap must trigger the
    // static_assert inside DispatchByTypeMap, failing this translation unit
    // before MappedType_t<TypeMap, kFLOAT16> is ever instantiated.
    DispatchByTypeMap<LowPrecisionMissingTypeMap, DataType::kFLOAT16>(
        DataType::kFLOAT16,
        []<typename T>() { (void)sizeof(T); },
        "compile-fail: unregistered low-precision dtype");
    return 0;
}
