Skip to content

Commit f71cfda

Browse files
authored
feat(hslm): gradient checkpointing for memory efficiency (#562)
- Add src/b2t/gradient_checkpointing.zig
- CheckpointStore: save activations every N layers, evict oldest
- MemoryBudget: compute checkpoint interval from available RAM
- Memory tracking: MB used, MB saved vs full activation storage
- 5 tests: save/count, skip non-checkpoint layers, eviction, memory tracking, budget interval

Closes #317
1 parent af32694 commit f71cfda

1 file changed

Lines changed: 170 additions & 0 deletions

File tree

src/b2t/gradient_checkpointing.zig

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
const std = @import("std");
2+
3+
/// Tunables for `CheckpointStore`.
pub const CheckpointConfig = struct {
    /// Layer interval between checkpoints. `CheckpointStore.saveCheckpoint`
    /// only stores an activation when the layer index is a multiple of this
    /// value. Defaults to 2.
    checkpoint_every_n: usize = 2,
    /// Maximum number of checkpoints held at once; `CheckpointStore.saveCheckpoint`
    /// evicts the oldest entry when this many are already stored. Defaults to 64.
    max_checkpoints: usize = 64,
};
7+
8+
/// Keeps heap copies of layer activations at a configurable layer interval,
/// dropping the oldest copy once `config.max_checkpoints` entries are held.
///
/// Ownership: every slice in `checkpoints` is allocated with `allocator` and
/// freed by `deinit` (or on eviction). `layer_sizes` is stored as passed to
/// `init`; it is neither copied nor freed by the store.
pub const CheckpointStore = struct {
    allocator: std.mem.Allocator,
    /// Stored activation copies, oldest first.
    checkpoints: std.ArrayList([]f32),
    layer_sizes: []const usize,
    config: CheckpointConfig,
    /// `layer_sizes.len - 1` (0 when `layer_sizes` is empty).
    num_layers: usize,

    /// Creates an empty store. No allocation happens here; the first
    /// allocation occurs on the first successful `saveCheckpoint`.
    pub fn init(
        allocator: std.mem.Allocator,
        layer_sizes: []const usize,
        config: CheckpointConfig,
    ) !CheckpointStore {
        return .{
            .allocator = allocator,
            .checkpoints = std.ArrayList([]f32).init(allocator),
            .layer_sizes = layer_sizes,
            .config = config,
            // Saturating subtraction: an empty `layer_sizes` yields 0 layers
            // instead of an out-of-bounds slice.
            .num_layers = layer_sizes.len -| 1,
        };
    }

    /// Frees every stored checkpoint and the backing list.
    pub fn deinit(self: *CheckpointStore) void {
        for (self.checkpoints.items) |cp| {
            self.allocator.free(cp);
        }
        self.checkpoints.deinit();
    }

    /// Stores a copy of `activation` when `layer_idx` is a multiple of
    /// `config.checkpoint_every_n`; otherwise does nothing.
    ///
    /// When `config.max_checkpoints` entries are already held, the oldest one
    /// is freed first. A `checkpoint_every_n` or `max_checkpoints` of 0 means
    /// nothing is stored (previously these divided by zero / removed from an
    /// empty list).
    ///
    /// Returns an allocation error if the copy or the list growth fails; the
    /// copy is released on that path.
    pub fn saveCheckpoint(self: *CheckpointStore, layer_idx: usize, activation: []const f32) !void {
        const every_n = self.config.checkpoint_every_n;
        const capacity = self.config.max_checkpoints;
        if (every_n == 0 or capacity == 0) return;
        if (layer_idx % every_n != 0) return;

        if (self.checkpoints.items.len >= capacity) {
            const oldest = self.checkpoints.orderedRemove(0);
            self.allocator.free(oldest);
        }

        const copy = try self.allocator.dupe(f32, activation);
        // Without this, a failed `append` would leak the fresh copy.
        errdefer self.allocator.free(copy);
        try self.checkpoints.append(copy);
    }

    /// Placeholder: currently performs no work and ignores all arguments.
    pub fn recompute(self: *CheckpointStore, from_layer: usize, to_layer: usize, activations: [][]f32) void {
        // Zig rejects unused parameters, so each one is discarded explicitly.
        _ = self;
        _ = from_layer;
        _ = to_layer;
        _ = activations;
    }

    /// Total size of all stored checkpoints in MiB (bytes / 1024 / 1024).
    pub fn memoryUsedMB(self: *const CheckpointStore) f64 {
        var total: usize = 0;
        for (self.checkpoints.items) |cp| {
            total += cp.len * @sizeOf(f32);
        }
        return @as(f64, @floatFromInt(total)) / (1024.0 * 1024.0);
    }

    /// Difference, in MiB, between storing `full_activations_size` f32
    /// elements for each of `num_layers` layers and what the store currently
    /// holds. May be negative if the store holds more than that.
    pub fn savedMemoryMB(self: *const CheckpointStore, full_activations_size: usize) f64 {
        const bytes_per_element: f64 = @sizeOf(f32);
        const elements: f64 = @floatFromInt(full_activations_size);
        const layers: f64 = @floatFromInt(self.num_layers);
        // Multiply in floating point (no usize overflow) and convert bytes to
        // MiB so both operands of the subtraction share a unit.
        const full_mb = (elements * bytes_per_element * layers) / (1024.0 * 1024.0);
        return full_mb - self.memoryUsedMB();
    }

    /// Number of checkpoints currently stored.
    pub fn checkpointCount(self: *const CheckpointStore) usize {
        return self.checkpoints.items.len;
    }
};
74+
75+
/// Byte-level memory budget used to decide how densely to checkpoint.
pub const MemoryBudget = struct {
    total_bytes: usize,
    model_bytes: usize,
    optimizer_bytes: usize,
    gradient_bytes: usize,
    activation_bytes: usize,

    /// Bytes left after subtracting model, optimizer, gradient and activation
    /// usage from `total_bytes`; 0 when usage meets or exceeds the total.
    pub fn availableForCheckpoints(self: MemoryBudget) usize {
        // Saturating adds keep an oversized budget from trapping on overflow.
        const used = self.model_bytes +| self.optimizer_bytes +| self.gradient_bytes +| self.activation_bytes;
        return self.total_bytes -| used;
    }

    /// Suggests a checkpoint interval (in layers) that fits the budget.
    ///
    /// - Returns 1 when every layer's activation fits.
    /// - Returns `num_layers` when not even one activation fits.
    /// - Otherwise returns ceil(num_layers / storable), so the number of
    ///   checkpointed layers never exceeds what can be stored.
    ///
    /// A `layer_activation_bytes` of 0 is treated as 1.
    pub fn recommendedCheckpointEvery(self: MemoryBudget, num_layers: usize, layer_activation_bytes: usize) usize {
        const max_storable = self.availableForCheckpoints() / @max(layer_activation_bytes, 1);
        if (max_storable >= num_layers) return 1;
        if (max_storable == 0) return num_layers;

        // Round up: flooring here (e.g. 12 layers / 5 slots -> 2) would ask
        // for 6 checkpoints when only 5 fit.
        const quotient = num_layers / max_storable;
        const has_remainder = num_layers % max_storable != 0;
        return quotient + @intFromBool(has_remainder);
    }
};
95+
96+
// Saving at interval 1 stores every layer; the store must hold a copy of each
// activation in insertion order.
test "checkpoint store saves and counts" {
    const allocator = std.testing.allocator;
    const sizes = [_]usize{ 4, 8, 4 };
    var store = try CheckpointStore.init(allocator, &sizes, .{ .checkpoint_every_n = 1 });
    defer store.deinit();

    const act1 = [_]f32{ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0 };
    try store.saveCheckpoint(0, &act1);
    try std.testing.expectEqual(@as(usize, 1), store.checkpointCount());

    const act2 = [_]f32{ 0.1, 0.2, 0.3, 0.4 };
    try store.saveCheckpoint(1, &act2);
    try std.testing.expectEqual(@as(usize, 2), store.checkpointCount());

    // The stored data must match what was passed in, not merely be counted.
    try std.testing.expectEqualSlices(f32, &act1, store.checkpoints.items[0]);
    try std.testing.expectEqualSlices(f32, &act2, store.checkpoints.items[1]);
}
110+
111+
// With an interval of 2, only even layer indices are stored.
test "checkpoint store skips non-checkpoint layers" {
    const allocator = std.testing.allocator;
    const sizes = [_]usize{ 4, 8, 4 };
    var store = try CheckpointStore.init(allocator, &sizes, .{ .checkpoint_every_n = 2 });
    defer store.deinit();

    const act = [_]f32{ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0 };

    try store.saveCheckpoint(0, &act);
    try std.testing.expectEqual(@as(usize, 1), store.checkpointCount());

    // Layer 1 is not a multiple of 2, so the count must not change.
    try store.saveCheckpoint(1, &act);
    try std.testing.expectEqual(@as(usize, 1), store.checkpointCount());

    try store.saveCheckpoint(2, &act);
    try std.testing.expectEqual(@as(usize, 2), store.checkpointCount());
}
124+
125+
// With room for two checkpoints, a third save must drop the first one saved
// and keep the two most recent, in order.
test "checkpoint store evicts oldest when full" {
    const allocator = std.testing.allocator;
    const sizes = [_]usize{ 4, 8, 4 };
    var store = try CheckpointStore.init(allocator, &sizes, .{ .checkpoint_every_n = 1, .max_checkpoints = 2 });
    defer store.deinit();

    const act1 = [_]f32{ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0 };
    const act2 = [_]f32{ 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0 };
    const act3 = [_]f32{ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5 };

    try store.saveCheckpoint(0, &act1);
    try store.saveCheckpoint(1, &act2);
    try store.saveCheckpoint(2, &act3);

    try std.testing.expectEqual(@as(usize, 2), store.checkpointCount());

    // Verify that it was specifically the oldest entry (act1) that went away.
    try std.testing.expectEqualSlices(f32, &act2, store.checkpoints.items[0]);
    try std.testing.expectEqualSlices(f32, &act3, store.checkpoints.items[1]);
}
141+
142+
// One checkpoint of eight f32 values occupies 32 bytes.
test "memory tracking" {
    const allocator = std.testing.allocator;
    const sizes = [_]usize{ 4, 8, 4 };
    var store = try CheckpointStore.init(allocator, &sizes, .{ .checkpoint_every_n = 1 });
    defer store.deinit();

    // Nothing stored yet.
    try std.testing.expectEqual(@as(f64, 0.0), store.memoryUsedMB());

    const act = [_]f32{ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0 };
    try store.saveCheckpoint(0, &act);

    const expected_mb: f64 = 32.0 / (1024.0 * 1024.0);
    try std.testing.expectApproxEqAbs(expected_mb, store.memoryUsedMB(), 1e-12);
}
155+
156+
// 1 GiB total minus 960 MiB in use leaves 64 MiB, enough for all twelve
// 1 MiB layer activations, so every layer can be checkpointed.
test "memory budget recommended checkpoint interval" {
    const mib: usize = 1024 * 1024;
    const budget = MemoryBudget{
        .total_bytes = 1024 * 1024 * 1024,
        .model_bytes = 512 * 1024 * 1024,
        .optimizer_bytes = 256 * 1024 * 1024,
        .gradient_bytes = 128 * 1024 * 1024,
        .activation_bytes = 64 * 1024 * 1024,
    };

    const available = budget.availableForCheckpoints();
    try std.testing.expectEqual(64 * mib, available);

    const every = budget.recommendedCheckpointEvery(12, 1024 * 1024);
    try std.testing.expectEqual(@as(usize, 1), every);
}

0 commit comments

Comments
 (0)