Skip to content
35 changes: 35 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,38 @@ link_infini_train_exe(test_precision_check)
add_executable(test_lora test/lora/test_lora.cc)
link_infini_train_exe(test_lora)

add_executable(test_scalar test/scalar/test_scalar.cc)
link_infini_train_exe(test_scalar)

# Positive test: dispatch over registered dtypes must compile and run.
add_executable(test_dtype_dispatch test/dispatch/test_dtype_dispatch.cc)
link_infini_train_exe(test_dtype_dispatch)

# Negative compile test: missing dtype registration must fail at compile time.
set(DTYPE_DISPATCH_COMPILE_FAIL_SOURCE
${PROJECT_SOURCE_DIR}/test/dispatch/test_dtype_dispatch_compile_fail.cc)

# try_compile runs at configure time; DTYPE_DISPATCH_COMPILE_UNEXPECTEDLY_SUCCEEDED
# is TRUE only if the source compiled, which for this negative test means failure.
# NOTE(review): any compile error makes this check "pass" — e.g. a wrong include
# path would mask a regression in the static_assert. Consider grepping
# DTYPE_DISPATCH_TRY_COMPILE_OUTPUT for the expected diagnostic to confirm the
# failure is the intended one.
try_compile(DTYPE_DISPATCH_COMPILE_UNEXPECTEDLY_SUCCEEDED
${CMAKE_BINARY_DIR}/CMakeFiles/try_compile_dtype_dispatch_missing_map
SOURCES ${DTYPE_DISPATCH_COMPILE_FAIL_SOURCE}
CMAKE_FLAGS
# Propagate the parent project's C++ standard into the isolated try_compile project.
"-DCMAKE_CXX_STANDARD=${CMAKE_CXX_STANDARD}"
"-DCMAKE_CXX_STANDARD_REQUIRED=ON"
"-DCMAKE_CXX_EXTENSIONS=OFF"
# The isolated project has no include dirs of its own; point it at the source tree.
"-DCMAKE_CXX_FLAGS=-I${PROJECT_SOURCE_DIR}"
OUTPUT_VARIABLE DTYPE_DISPATCH_TRY_COMPILE_OUTPUT
)

if(DTYPE_DISPATCH_COMPILE_UNEXPECTEDLY_SUCCEEDED)
message(FATAL_ERROR
"dtype dispatch compile-fail test unexpectedly succeeded.\n"
"Source: ${DTYPE_DISPATCH_COMPILE_FAIL_SOURCE}\n"
"Output:\n${DTYPE_DISPATCH_TRY_COMPILE_OUTPUT}")
endif()

# Marker target so the negative check is visible in the build graph; the actual
# verification already happened above at configure time.
add_custom_target(test_dtype_dispatch_compile_fail
COMMAND ${CMAKE_COMMAND} -E echo
"dtype dispatch compile-fail check passed (missing dtype registration correctly fails to compile)."
VERBATIM
)

add_dependencies(test_dtype_dispatch test_dtype_dispatch_compile_fail)
185 changes: 185 additions & 0 deletions docs/device_guard_design.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# Device Guard Design
Device 注册机制是 InfiniTrain 面向多硬件后端的统一运行时抽象与插件化接入基础设施。

## 1. 核心组件

InfiniTrain 的 device 机制由三类核心组件构成:

```C++
+-------------------+
| DeviceGuard | ← 对外 RAII 接口(public)
+-------------------+
|
v
+-------------------+
| DeviceGuardImpl | ← 后端抽象接口(virtual)
+-------------------+
^
|
+-------------------+
| DeviceGuardImpl |
| Registry | ← 全局注册表(singleton)
+-------------------+
```

其中 DeviceGuard 与 DeviceGuardImpl 的关系是:

| 组件 | 职责 |
| --------------- | ------------------------------------------------------------ |
| DeviceGuard | 管理 “当前在哪个 device 上” 的上下文语义(RAII),语义与 device index 绑定;负责 device 的保存/切换/恢复,并将具体 runtime 操作转发给对应的 DeviceGuardImpl。 |
| DeviceGuardImpl | 管理 “在该类 device 上如何执行 runtime 操作”,语义与 device type 绑定;对外提供 设备管理查询、stream、blas、同步、内存 等运行时能力接口。 |

### 1.1 DeviceGuardImpl:运行时能力抽象(对外暴露)

DeviceGuardImpl 是 InfiniTrain 中 device runtime 能力的统一抽象接口,并且是框架内部对外暴露的能力接口,封装了所有与 device 相关的行为(待补充 event 相关接口):

```C++
// ----------------------------------------------------------------------
// Device management
// ----------------------------------------------------------------------

virtual Device GetDevice() const = 0;

virtual void SetDevice(Device device) const;

virtual int8_t DeviceCount() const;

virtual Device::DeviceType Type() const = 0;

// ----------------------------------------------------------------------
// Stream management
// ----------------------------------------------------------------------

virtual Stream *GetStream(Device) const;

// ----------------------------------------------------------------------
// Synchronization
// ----------------------------------------------------------------------

virtual void SynchronizeDevice(Device) const;

virtual void SynchronizeStream(Stream *) const;

// ----------------------------------------------------------------------
// BLAS handle
// ----------------------------------------------------------------------

virtual BlasHandle *GetBlasHandle(Device) const;

// ----------------------------------------------------------------------
// Memory operations
// ----------------------------------------------------------------------

virtual void Malloc(void **dev_ptr, size_t size) = 0;

virtual void MallocAsync(void **dev_ptr, size_t size, Stream *stream);

virtual void Free(void *dev_ptr) = 0;

virtual void FreeAsync(void *dev_ptr, Stream *stream);

virtual void Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind) = 0;

virtual void MemcpyAsync(void *dst, const void *src, size_t count, MemcpyKind kind, Stream *stream);

virtual void ResetMemPoolHighWatermarks(Device device) const;

virtual std::pair<size_t, size_t> GetMemPoolPeakMB(Device device) const;
```

### 1.2 DeviceGuard:RAII 前端接口

DeviceGuard 是设备上下文的 RAII 管理器,其职责严格限定为:

- 保存当前 device
- 切换到目标 device
- 在作用域结束时恢复原 device

DeviceGuard 不直接提供任何运行时能力接口。

使用示例:

```C++
{
DeviceGuard guard(Device(DeviceType::kCUDA, 1));
// 当前线程的 device 上下文被切换到 CUDA:1
// 所有 runtime 操作将发生在 CUDA:1
}
// 离开作用域后,自动恢复进入前的 device
```

### 1.3 DeviceGuardImplRegistry:全局注册表

`DeviceGuardImplRegistry`是 InfiniTrain 中用于管理 device runtime 后端实现的全局注册表,采用 singleton 模式,生命周期覆盖整个进程。

其核心职责是维护`DeviceType -> DeviceGuardImpl`的一对一映射关系:

```C++
std::unordered_map<Device::DeviceType, std::unique_ptr<DeviceGuardImpl>> impls_;
```

## 2. Runtime Capability 获取与使用范式

### 2.1 获取入口

```C++
DeviceGuardImpl* GetDeviceGuardImpl(Device::DeviceType type);
```

- 返回指定`DeviceType`的 DeviceGuardImpl
- 若未注册对应 backend,直接报错

### 2.2 推荐使用模式(标准范式)

```C++
auto device = tensor->GetDevice();
const int64_t num_elements = tensor->NumElements();
std::vector<float> buffer(num_elements);

{
// 1. 切换 device 上下文(RAII scope)
core::DeviceGuard guard(device);

// 2. 获取 runtime capability
auto* impl = core::GetDeviceGuardImpl(device.type());

// 3. 执行 runtime 操作
const core::MemcpyKind kind =
device.type() == Device::DeviceType::kCPU
? core::MemcpyKind::kD2D // CPU: host-host memcpy
: core::MemcpyKind::kH2D; // Device: host-device copy

impl->MemcpyAsync(
tensor->DataPtr(), // dst
buffer.data(), // src
num_elements * sizeof(float), // count
kind, // kind(说明:在 CPU backend 中,kD2D 对应普通 memcpy)
impl->GetStream(device) // stream
);
} // <-- DeviceGuard 在此处析构,device 上下文被恢复
```

## 3. Backend 注册机制(静态注册)

### 3.1 注册宏

```C++
#define INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(device_type, class_impl) \
    /* NOTE: `##__COUNTER__` pasted directly would NOT expand the counter — every */ \
    /* expansion yields the same identifier and two registrations in one TU collide. */ \
    /* Route through the two-level CAT indirection from common/common.h instead. */ \
    static const bool CAT(__infini_train_device_guard_registered_, __COUNTER__) = []() { \
        infini_train::core::DeviceGuardImplRegistry::Instance().Register(device_type, std::make_unique<class_impl>()); \
        return true; \
    }();
```

采用静态变量 + lambda 在程序启动阶段完成注册。

### 3.2 使用示例(CUDA Backend)

```C++
class CudaGuardImpl : public DeviceGuardImpl {
...
};

INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(Device::DeviceType::kCUDA, CudaGuardImpl)
```

78 changes: 78 additions & 0 deletions docs/dtype_registry_design.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Low-Precision DType Abstraction & Backend Registration Design
低精度 dtype 抽象是 InfiniTrain 面向多后端的统一类型语义与显式注册基础设施。

## 1. Design In One Diagram

```
framework code ──► FP16 / BF16 (datatype.h, 纯软件实现,提供基本转换操作)
PromoteDataTypes(DataType, DataType)

kernel code ──► DispatchCpuFunc / DispatchCudaFunc / DispatchXxxFunc
BackendTypeMap<Dev, DType> (主模板只声明不定义)
├─ kFLOAT16 / kBFLOAT16 → 后端在 *_dispatch.h 显式特化后注册
│ └── CUDA: __half / __nv_bfloat16
│ └── CPU : FP16 / BF16
└─ 其它 10 个标量 dtype 使用默认注册 → INFINI_REGISTER_STANDARD_BACKEND_TYPES(DEV)
```

要点:

- 框架层不提供任何「DataType → 后端 C++ 类型」映射路径;所有具体类型绑定均在后端通过 `BackendTypeMap<Dev, DType>` 完成。
- `BackendTypeMap<Dev, DType>` 主模板**只声明不定义**,只有后端显式特化并完成注册的组合才允许参与 kernel dispatch;未注册组合会在模板实例化阶段被 `static_assert` 于编译期拦截。

## 2. Core API

| API | 位置 | 说明 |
| --- | --- | --- |
| `struct FP16 / BF16` | [datatype.h](../infini_train/include/datatype.h) | 16-bit 软件包装(IEEE-754 half / truncated bf16),承担框架身份、存储布局、fallback 转换;不承担后端高性能算术语义。 |
| `PromoteDataTypes(DataType, DataType)` | [datatype.h](../infini_train/include/datatype.h) | 纯枚举到枚举的类型提升。规则:FP16+BF16→FP32;浮点优先于整数;同类按字节宽取大。 |
| `BackendTypeMap<Dev, DType>` | [core/backend_type_map.h](../infini_train/include/core/backend_type_map.h) | 主模板**只声明不定义**;后端通过显式特化提供 `::type`。 |
| `INFINI_REGISTER_STANDARD_BACKEND_TYPES(DEV)` | [core/backend_type_map.h](../infini_train/include/core/backend_type_map.h) | 一次性注册 10 个非低精度 dtype(`kUINT8…kFLOAT64`)到对应 C++ 标量。 |
| `DispatchCpuFunc / DispatchCudaFunc<AllowedDTypes...>` | `src/core/runtime/{cpu,cuda}/{cpu,cuda}_dispatch.h` | 后端 dispatch 入口,底层转发到 `DispatchByTypeMap<TypeMap, AllowedDTypes...>`。 |

## 3. How To Add A New Backend

按以下清单操作,**不需要**修改 `infini_train/include/` 下的任何框架头文件,也不需要 `#ifdef`:

1. 在后端的 `*_dispatch.h` 里 include `core/backend_type_map.h` 与 `dtype_dispatch.h`。
2. 调用 `INFINI_REGISTER_STANDARD_BACKEND_TYPES(Device::DeviceType::kXxx)` 注册 10 个标准 dtype。
3. 若硬件支持低精度,显式特化 `BackendTypeMap<kXxx, kFLOAT16>` / `BackendTypeMap<kXxx, kBFLOAT16>` 指向后端本地 16-bit 标量类型;不支持则直接跳过,调用方一旦 dispatch 到未注册的 dtype 会在编译期触发 `static_assert`。
4. 定义 `XxxTypeMap<DType>` 转发/继承到 `BackendTypeMap<kXxx, DType>`。
5. 提供 `DispatchXxxFunc` 入口,转发到 `DispatchByTypeMap<XxxTypeMap, AllowedDTypes...>`。

### Example

```cpp
// xxx_dispatch.h
#include "infini_train/include/core/backend_type_map.h"
#include "infini_train/include/dtype_dispatch.h"

namespace infini_train::core {
// 若硬件支持低精度,显式特化 FP16/BF16
template <> struct BackendTypeMap<Device::DeviceType::kXxx, DataType::kFLOAT16> { using type = xxx_half; };
template <> struct BackendTypeMap<Device::DeviceType::kXxx, DataType::kBFLOAT16> { using type = xxx_bfloat; };
} // namespace infini_train::core

INFINI_REGISTER_STANDARD_BACKEND_TYPES(infini_train::Device::DeviceType::kXxx)

namespace infini_train::core::xxx {
template <DataType DType>
struct XxxTypeMap : BackendTypeMap<Device::DeviceType::kXxx, DType> {};

template <DataType... AllowedDTypes, typename Functor, typename... Args>
auto DispatchXxxFunc(DataType dtype, Functor &&f, std::string_view ctx = "", Args &&...a) {
return DispatchByTypeMap<XxxTypeMap, AllowedDTypes...>(
dtype, std::forward<Functor>(f), ctx, std::forward<Args>(a)...);
}
} // namespace infini_train::core::xxx
```

## 4. Failure Modes

| 情形 | 表现 |
| --- | --- |
| 后端未注册某个 dtype(`BackendTypeMap<Dev, DType>` 无特化),但被 dispatch 命中 | 编译期 `static_assert` 触发,错误信息指向 `BackendTypeMap` 的显式注册要求。 |
| dispatch 的 dtype 不在调用点 `AllowedDTypes...` 白名单内 | 运行期 `LOG_UNSUPPORTED_DTYPE` 报错。 |
16 changes: 13 additions & 3 deletions infini_train/include/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,21 @@

#include "infini_train/include/datatype.h"

/**
 * General Utility Macros
 */
#define EXPAND(X) X
// This macro lets you pass an arbitrary expression that may contain internal
// commas to another macro without having the commas causing the expression
// to be interpreted as being multiple arguments
// Basically an alternative for __VA_OPT__ before C++20
// ref: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/Dispatch_v2.h
#define WRAP(...) __VA_ARGS__
// Two-level concatenation so that macro arguments (e.g. __COUNTER__, __LINE__)
// are expanded before token pasting.
#define CAT(a, b) CAT_(a, b)
#define CAT_(a, b) a##b

// Integer ceiling division: smallest q such that q * y >= x (for positive x, y).
#define CEIL_DIV(x, y) (((x) + (y)-1) / (y))
// Log MSG at LEVEL together with the call site's file and line.
#define LOG_LOC(LEVEL, MSG) LOG(LEVEL) << MSG << " at " << __FILE__ << ":" << __LINE__
// Fatal-log an unsupported dtype. BUG FIX: the body previously referenced a
// hard-coded `dtype` identifier instead of the DTYPE parameter, so the macro
// only compiled when the caller's variable happened to be named `dtype`.
#define LOG_UNSUPPORTED_DTYPE(DTYPE, CONTEXT_IDENTIFIER) \
LOG_LOC(FATAL, WRAP(CONTEXT_IDENTIFIER << ": Unsupported data type: " \
+ kDataTypeToDesc.at(static_cast<infini_train::DataType>(DTYPE))))

inline std::vector<int64_t> ComputeStrides(const std::vector<int64_t> &dims) {
std::vector<int64_t> strides(dims.size(), 1);
Expand Down
35 changes: 28 additions & 7 deletions infini_train/include/common/cpu/common_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,41 @@
#include <type_traits>
#include <utility>

#include "infini_train/include/datatype.h"

namespace infini_train::common::cpu {

namespace detail {

// Converts between arbitrary scalar types, including the framework's software
// FP16/BF16 wrappers. FP16/BF16 don't support implicit conversion to/from
// other scalars, so any conversion involving them is routed through float.
template <typename DST, typename SRC> DST CastImpl(SRC &&x) {
    using SrcBase = std::remove_cvref_t<SRC>;
    if constexpr (std::is_same_v<DST, SrcBase>) {
        // Same type: forward so an rvalue source is moved rather than copied.
        return std::forward<SRC>(x);
    } else if constexpr (std::is_same_v<DST, FP16> || std::is_same_v<DST, BF16>) {
        // Destination is a framework 16-bit type: convert via float
        return DST(static_cast<float>(std::forward<SRC>(x)));
    } else if constexpr (std::is_same_v<SrcBase, FP16> || std::is_same_v<SrcBase, BF16>) {
        // Source is a framework 16-bit type: widen to float first
        return static_cast<DST>(static_cast<float>(x));
    } else {
        // Plain scalar-to-scalar conversion.
        return static_cast<DST>(std::forward<SRC>(x));
    }
}

} // namespace detail

/**
 * Converts a value between arbitrary types, including framework FP16/BF16.
 *
 * All conversion logic lives in detail::CastImpl; this wrapper only enforces
 * that the destination is a value type and perfect-forwards the input.
 *
 * @tparam DST Destination type
 * @tparam SRC Source type (deduced)
 * @param x Input value
 * @return Value converted to DST type
 */
template <typename DST, typename SRC> DST Cast(SRC &&x) {
    static_assert(!std::is_reference_v<DST>, "Cast cannot return reference types");
    // The former C-style cast fallback (`(DST)(...)`) and its TODO are gone:
    // CastImpl now covers FP16/BF16 as well as plain scalars.
    return detail::CastImpl<DST>(std::forward<SRC>(x));
}

} // namespace infini_train::common::cpu
Loading
Loading