-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
70 lines (62 loc) · 1.83 KB
/
train.py
File metadata and controls
70 lines (62 loc) · 1.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
from pathlib import Path
from game_environment import GameEnvironment
from stable_baselines3 import PPO
from custom_features_extractor import CustomFeaturesExtractor
# --- Configuration ---------------------------------------------------------
DISPLAY_SHAPE = (480, 480)   # first argument handed to GameEnvironment
FPS = 24                     # frame rate; its reciprocal is passed to the environment
TOTAL_TIMESTEPS = 100_000    # timestep budget used by the training calls below
SEED = 13                    # random seed constant
NUM_EXPERT_EPISODES = 128

# --- Environment -----------------------------------------------------------
# The second constructor argument is 1 / FPS (presumably the duration of one
# simulation step in seconds -- confirm against GameEnvironment).
env = GameEnvironment(DISPLAY_SHAPE, 1.0 / FPS)
env.reset()
# Paths
# Checkpoint file: loaded here when it exists, and written back after training.
model_path = Path("ppo_bouncing_balls_latest.zip")

# Create or load model
if model_path.is_file():
    # Resume from the previous checkpoint, re-attaching the live environment.
    # The seed is deliberately NOT re-applied on this path: re-seeding every
    # resumed session would replay the same random stream each time.
    model = PPO.load(model_path, env=env)
    print("Loaded existing model.")
else:
    # Fresh model: route observations through the project's custom feature
    # extractor, which is asked for a 128-dimensional feature vector.
    policy_kwargs = dict(
        features_extractor_class=CustomFeaturesExtractor,
        features_extractor_kwargs=dict(features_dim=128),
    )
    model = PPO(
        "MultiInputPolicy",
        env,
        policy_kwargs=policy_kwargs,
        verbose=1,
        learning_rate=1e-5,
        n_steps=512,
        batch_size=128,
        gamma=0.999,
        clip_range=0.05,
        ent_coef=0.01,
        # SEED was declared in the constants but never used; pass it so that
        # weight initialisation and sampling are reproducible across runs.
        seed=SEED,
    )
# Train with PPO (fine-tuning)
#
# Two phases:
#   1. Freeze the actor-side parameters and train with vf_coef = 1.0 for
#      TOTAL_TIMESTEPS // 5 steps.
#   2. Re-enable gradients on every policy parameter, restore vf_coef = 0.5
#      and train for TOTAL_TIMESTEPS steps.
# Whatever happens (including Ctrl-C), the model is saved and the environment
# is closed afterwards.
print("\nStarting PPO fine-tuning...")
try:
    # Phase 1: Train only value function
    print("Phase 1: Train value function only...")
    for frozen_module in (
        model.policy.action_net,
        model.policy.mlp_extractor.policy_net,
    ):
        for param in frozen_module.parameters():
            param.requires_grad = False
    # Policies with a Gaussian / gSDE action distribution keep the action
    # noise in a standalone `log_std` parameter that is not part of
    # `action_net`; freeze it too so phase 1 really leaves the actor untouched.
    # NOTE(review): no-op for discrete action spaces, where the attribute
    # does not exist.
    log_std = getattr(model.policy, "log_std", None)
    if log_std is not None:
        log_std.requires_grad = False
    model.vf_coef = 1.0
    model.learn(total_timesteps=TOTAL_TIMESTEPS // 5)

    # Phase 2: Unfreeze and train everything
    print("Phase 2: Training full policy...")
    for param in model.policy.parameters():
        param.requires_grad = True
    model.vf_coef = 0.5
    model.learn(total_timesteps=TOTAL_TIMESTEPS)
except Exception as e:
    # Best-effort: report the failure but still fall through to the save below
    # so that partial progress is not lost.
    print(f"Training failed: {e}")
    import traceback
    traceback.print_exc()
finally:
    try:
        model.save(model_path)
        print(f"Model saved to {model_path}")
    finally:
        # Always release the environment, even if saving the model raised.
        env.close()