训练第一个强化学习模型

张开发
2026/5/22 1:27:02 15 分钟阅读
训练第一个强化学习模型
认识stable_baseline3stable_baseline3提供了许多模型如下列表名称动作空间建议应用场景核心优势PPO连续 离散全能选手如机器人走动、金融交易、游戏 AI极其稳定对超参数不敏感支持大规模并行训练。DQN仅离散经典游戏Atari、开关控制、迷宫寻路理解简单在离散控制领域非常经典且有效。SAC仅连续复杂物理模拟、机械臂抓取、自动驾驶探索效率极高能自动寻找最优路径且不轻易陷入局部最优。TD3仅连续工业控制、无人机飞行、精密动作针对 DDPG 的缺陷做了改进训练过程比 SAC 更平滑。A2C连续 离散简单逻辑测试、快速原型验证结构简单虽然不如 PPO 稳定但在特定并行环境下速度极快。在声明模型中可以设置多种参数这里列出常用的目前不需要搞懂都有什么作用后面有文章会详细讲解训练参数learning_rate学习率gamma折扣因子batch_size更新模型使用数据量verbose打印信息模式。0-静默模式1-信息模式2-调试模式device指定训练设备cuda使用显卡cpu使用cpu模型规则MlpPolicy多层感知机。适用于状态是数值场景传感器等CnnPolicy卷积神经网络。适用于状态是图像场景游戏等训练第一个强化学习模型案例案例描述训练一个gymnasium默认提供的游戏环境平衡杆游戏。import gymnasium as gym from stable_baselines3 import PPO env gym.make(CartPole-v1) model PPO(MlpPolicy, env, verbose1, devicecuda) print(开始训练...) model.learn(total_timesteps10000) print(正在保存模型...) model.save(ppo_cartpole) print(正在读取模型...) env gym.make(CartPole-v1, render_modehuman) loaded_model PPO.load(ppo_cartpole, envenv) print(训练结束开始演示...) obs, _ env.reset() for i in range(1000): action, _states loaded_model.predict(obs, deterministicTrue) obs, reward, terminated, truncated, info env.step(action) if terminated or truncated: obs, _ env.reset() env.close()代码解释代码流程如下初始化环境模型-训练模型-保存模型-加载模型-模型预测初始化环境模型初始化模型以及游戏的环境env gym.make(CartPole-v1) model PPO(MlpPolicy, env, verbose1, devicecuda) env gym.make(CartPole-v1, render_modehuman)gym中的make方法利用默认的游戏环境CartPole-v1是游戏名下面有一个render_modehuman参数用于标识是否展示画面。训练时展示画面会降低训练的速度一般在预测时才使用训练模型model.learn(total_timesteps10000)total_timesteps训练10000次保存模型model.save(ppo_cartpole)ppo_cartpole为保存模型的名字这里是保存在当前文件夹中。加载模型loaded_model PPO.load(ppo_cartpole, envenv)第一个参数刚刚保存的模型路径第二个参数训练的环境模型预测obs, _ env.reset() for i in range(1000): action, _states loaded_model.predict(obs, deterministicTrue) obs, reward, terminated, truncated, info env.step(action) if terminated or truncated: obs, _ env.reset()env.reset()重置环境返回初始观测值obs和info(这里没用到)

更多文章