Skip to content

Commit 8e3ea54

Browse files
committed
fix(rng): use fixed seed for deterministic LoRA init
Replace std::random_device-based seeding with a fixed seed (42 for the global generator, 42 + omp_get_thread_num() for per-thread OMP generators) so LoRA initialization is reproducible across runs. Also move model->To(device) and PrecisionChecker::BuildNameMap(...) to after LoRA injection so the injected LoRA parameters are included.
1 parent 3d61e10 commit 8e3ea54

2 files changed

Lines changed: 7 additions & 8 deletions

File tree

example/llama3/main.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,6 @@ void Train(const nn::parallel::Rank &rank) {
175175
model = std::make_shared<LLaMA3>(model_config);
176176
}
177177

178-
model->To(device);
179-
180-
utils::PrecisionChecker::BuildNameMap(model.get());
181-
182178
// Apply LoRA using GetLoRAModel (in-place injection)
183179
bool lora_enabled = FLAGS_lora_rank > 0;
184180
if (lora_enabled) {
@@ -198,6 +194,10 @@ void Train(const nn::parallel::Rank &rank) {
198194
nn::lora::PrintLoRASummary(model, rank.GlobalRank());
199195
}
200196

197+
model->To(device);
198+
199+
utils::PrecisionChecker::BuildNameMap(model.get());
200+
201201
LOG(INFO) << "Rank " << rank.GlobalRank() << ": Model loaded to device.";
202202

203203
DataType dtype;

infini_train/src/nn/init.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@
2222

2323
namespace infini_train::nn::init {
2424
namespace {
25-
static std::random_device rd;
26-
static std::mt19937 gen(rd());
25+
static std::mt19937 gen(42);
2726
} // namespace
2827

2928
std::shared_ptr<Tensor> Normal(const std::shared_ptr<Tensor> &tensor, float mean, float std,
@@ -34,7 +33,7 @@ std::shared_ptr<Tensor> Normal(const std::shared_ptr<Tensor> &tensor, float mean
3433
#ifdef USE_OMP
3534
#pragma omp parallel
3635
{
37-
std::mt19937 local_gen(std::random_device{}() + omp_get_thread_num());
36+
std::mt19937 local_gen(42 + omp_get_thread_num());
3837
std::normal_distribution<float> local_dis(mean, std);
3938
#pragma omp for
4039
for (int i = 0; i < buffer.size(); ++i) {
@@ -126,7 +125,7 @@ std::shared_ptr<Tensor> Uniform(const std::shared_ptr<Tensor> &tensor, float a,
126125
#ifdef USE_OMP
127126
#pragma omp parallel
128127
{
129-
std::mt19937 local_gen(std::random_device{}() + omp_get_thread_num());
128+
std::mt19937 local_gen(42 + omp_get_thread_num());
130129
std::uniform_real_distribution<float> local_dis(a, b);
131130
#pragma omp for
132131
for (int i = 0; i < buffer.size(); ++i) {

0 commit comments

Comments (0)