选自oxen.ai

作者:Greg Schoeninger

编译:陈陈、泽南

RTX 3080 移动版能训练哪种大模型?本文为那些 GPU 资源有限时使用 GRPO 训练的开发者提供了宝贵的指导。

自 DeepSeek-R1 发布以来,群组相对策略优化(GRPO)因其有效性和易于训练而成为大型语言模型强化学习的热门话题。R1 论文展示了如何使用 GRPO 从遵循 LLM(DeepSeek-v3)的基本指令转变为推理模型(DeepSeek-R1)。

GRPO 是一种在线学习算法(online learning algorithm),它通过使用训练过程中由训练模型自身生成的数据来进行迭代改进。GRPO 的目标是最大化生成补全(completions)的优势函数(advantage),同时确保模型保持在参考策略(reference policy)附近。



本文的目的是帮你节省一些时间,让你根据硬件预算选择合适的模型大小。在开始微调时,你必须做出的重要决定是选择模型大小,以及你是执行完全微调还是参数高效微调(PEFT)。

文章作者来自 AI 公司 Oxen.ai 的 CEO Greg Schoeninger。



原文链接:https://www.oxen.ai/blog/grpo-vram-requirements-for-the-gpu-poor

作者表示,他发现 trl 库中已经有一个易于使用的 GRPO 实现,便立刻开始了训练,使用的硬件是配备了 16GB 显存的 Nvidia GeForce RTX 3080 的小型笔记本电脑。正如大家可能遇到的问题,作者发现示例代码中的参数设置导致了一个巨大的显存不足(OOM,out of memory )错误。

  1. torch
  2. OutOfMemoryError
  3. CUDA
  4. out
  5. of memory
  6. Tried
  7. to allocate
  8. 1.90
  9. GiB
  10. GPU
  11. 0
  12. has a total capacity of
  13. GiB
  14. of which
  15. 1.28
  16. GiB
  17. is
  18. free
  19. Including
  20. non
  21. PyTorch
  22. memory
  23. this
  24. process has
  25. GiB
  26. memory
  27. in
  28. use
  29. Of
  30. the allocated memory
  31. GiB
  32. is
  33. allocated
  34. by
  35. PyTorch
  36. and
  37. 2.41
  38. GiB
  39. is
  40. reserved
  41. by
  42. PyTorch
  43. but unallocated
  44. If
  45. reserved but unallocated memory
  46. is
  47. large
  48. try
  49. setting PYTORCH_CUDA_ALLOC_CONF
  50. expandable_segments
  51. True
  52. to avoid fragmentation
  53. See
  54. documentation
  55. for
  56. Memory
  57. Management
  58. //pytorch.org/docs/stable/notes/cuda.html#environment-variables)

实际使用情况

作者表示,他们进行了一系列实验,以确定训练各种大小的模型所需的显存(VRAM)要求。参数数量从 5 亿到 140 亿不等,他们比较了权重的完全微调与参数高效微调(使用 LoRA),所有训练运行都在英伟达 H100 上完成,因此这里的 OOM 意味着 >80GB 的 VRAM。



在表格中,你可以找到 GSM8K 数据集上训练的前 100 步中的峰值内存使用情况。用于实验的模型是:



所有实验均使用 Shadeform 的 GPU 市场完成,因此每次实验只需要花费几美元 H100。

实验结果表明,内存需求随着模型大小和训练方式的不同而显著变化。例如,全参数微调比 PEFT 需要更多的内存。

为什么 GRPO 对内存需求较高

这要从 GRPO 的原理说起,这是它的流程图。



GRPO 对内存需求较高的原因在于,其内部涉及多个模型,并且在训练数据中每个查询会产生多个输出。上图中的策略模型、参考模型和奖励模型各自都是一个需要进行推理的 LLM。(尽管从技术上讲,奖励模型可能不需要参数化,可以只是一个 Python 函数或正则表达式,但不影响 GRPO 对内存的高需求。)

为什么 8-Bit 优化和梯度检查点有助于减少内存占用?

通常来讲,训练一个大型语言模型需要在内存中存储三种主要类型的信息:模型参数、模型学习所需的梯度、优化器的跟踪数据。

对上述内容我们可以这样理解:如果模型的参数占用了 X 的空间,那么梯度也会占用大约相同的空间。然后,像 AdamW 这样的优化器需要更多的空间,因为它们就像一个记录员,跟踪最近的更新历史,以便更好地决定未来的优化。

为了减轻这种内存负担,通常采用两种技术:

  • 首先,可以使用像 AdamW 这样的 8-bit 优化器版本,它们能更高效地存储跟踪数据,同时仍保持良好的性能 —— 类似于压缩照片可以节省空间,同时保留大部分图像质量;
  • 其次,使用梯度检查点技术,这就像在训练过程中拍摄快照,而不是记录所有内容。虽然这会使训练速度减慢约 20-30%,但它显著减少了内存使用。

结合这些技术,即使对 GPU 资源有限的人来说,也能够训练更大的模型。

代码示例

像 trl 这样的库已经开始支持 GRPO,使得微调由 transformers 构成的 LLM 变得非常简单。代码也非常简洁,只需将训练器替换为 GRPOTrainer 并定义一些奖励即可。GRPO 的最小代码量大约只有 99 行,如果你使用的是像 meta-llama/Llama-3.2-1B-Instruct 这样的小型模型和像 openai/GSM8K 这样的数据集,可以非常快速地启动。

trl 项目地址:https://github.com/huggingface/trl?ref=ghost.oxen.ai

  1. import
  2. torch
  3. from
  4. datasets
  5. import
  6. load_dataset
  7. Dataset
  8. from
  9. transformers
  10. import
  11. AutoTokenizer
  12. AutoModelForCausalLM
  13. from
  14. trl
  15. import
  16. GRPOConfig
  17. GRPOTrainer
  18. import
  19. re
  20. SYSTEM_PROMPT
  21. Respond in the following format:
  22. def
  23. extract_hash_answer
  24. text
  25. str
  26. str
  27. None
  28. if
  29. "####"
  30. not
  31. in
  32. text
  33. return
  34. None
  35. return
  36. text
  37. split
  38. "####"
  39. 1
  40. strip
  41. def
  42. get_gsm8k_questions
  43. split
  44. "train"
  45. Dataset
  46. data
  47. load_dataset
  48. 'openai/gsm8k'
  49. 'main'
  50. split
  51. data
  52. data
  53. map
  54. lambda
  55. 'prompt'
  56. 'role'
  57. 'system'
  58. 'content'
  59. SYSTEM_PROMPT
  60. },
  61. 'role'
  62. 'user'
  63. 'content'
  64. 'question'
  65. ],
  66. 'answer'
  67. extract_hash_answer
  68. 'answer'
  69. return
  70. data
  71. def
  72. extract_xml_answer
  73. text
  74. str
  75. str
  76. answer
  77. text
  78. split
  79. 1
  80. answer
  81. answer
  82. split
  83. ""
  84. 0
  85. return
  86. answer
  87. strip
  88. def
  89. format_reward_func
  90. completions
  91. kwargs
  92. list
  93. float
  94. """Reward function that checks if the completion has a specific format."""
  95. pattern
  96. r
  97. "^\n\n$"
  98. \n.*?\n
  99. \n.*?\n
  100. responses
  101. completion
  102. 0
  103. "content"
  104. for
  105. completion
  106. in
  107. completions
  108. matches
  109. re
  110. match
  111. pattern
  112. r
  113. for
  114. r
  115. in
  116. responses
  117. return
  118. 0.5
  119. if
  120. match
  121. else
  122. 0.0
  123. for
  124. match
  125. in
  126. matches
  127. def
  128. accuracy_reward_func
  129. prompts
  130. completions
  131. answer
  132. kwargs
  133. list
  134. float
  135. """Reward function that extracts the answer from the xml tags and compares it to the correct answer."""
  136. responses
  137. completion
  138. 0
  139. 'content'
  140. for
  141. completion
  142. in
  143. completions
  144. extracted_responses
  145. extract_xml_answer
  146. r
  147. for
  148. r
  149. in
  150. responses
  151. return
  152. 2.0
  153. if
  154. r
  155. a
  156. else
  157. 0.0
  158. for
  159. r
  160. a
  161. in
  162. zip
  163. extracted_responses
  164. answer
  165. def
  166. main
  167. dataset
  168. get_gsm8k_questions
  169. model_name
  170. "meta-llama/Llama-3.2-1B-Instruct"
  171. model
  172. AutoModelForCausalLM
  173. from_pretrained
  174. model_name
  175. torch_dtype
  176. torch
  177. bfloat16
  178. attn_implementation
  179. "flash_attention_2"
  180. device_map
  181. None
  182. to
  183. "cuda"
  184. tokenizer
  185. AutoTokenizer
  186. from_pretrained
  187. model_name
  188. tokenizer
  189. pad_token
  190. tokenizer
  191. eos_token
  192. training_args
  193. GRPOConfig
  194. output_dir
  195. "output"
  196. learning_rate
  197. 5e-6
  198. adam_beta1
  199. 0.9
  200. adam_beta2
  201. 0.99
  202. weight_decay
  203. 0.1
  204. warmup_ratio
  205. 0.1
  206. lr_scheduler_type
  207. 'cosine'
  208. logging_steps
  209. 1
  210. bf16
  211. True
  212. per_device_train_batch_size
  213. 1
  214. gradient_accumulation_steps
  215. 4
  216. num_generations
  217. 4
  218. max_prompt_length
  219. 256
  220. max_completion_length
  221. 786
  222. num_train_epochs
  223. 1
  224. save_steps
  225. 100
  226. save_total_limit
  227. 1
  228. max_grad_norm
  229. 0.1
  230. log_on_each_node
  231. False
  232. trainer
  233. GRPOTrainer
  234. model
  235. model
  236. processing_class
  237. tokenizer
  238. reward_funcs
  239. format_reward_func
  240. accuracy_reward_func
  241. ],
  242. args
  243. training_args
  244. train_dataset
  245. dataset
  246. trainer
  247. train
  248. if
  249. __name__
  250. "__main__"
  251. main

Num Generations 有什么用

Num Generations 是一个超参数,它决定了我们将在训练数据中对每个查询采样多少个补全。然而,这会显著增加 VRAM 的消耗。



目前有一个开放的 GitHub 问题,可能会帮助解决内存瓶颈问题,可以参考如下链接

地址:https://github.com/huggingface/trl/issues/2709?ref=ghost.oxen.ai

对于 num_completions=8,16,64 (DeepSeekMath 论文使用的 64),作者表示,不用再次计算上述所有值,而是使用了 1B 参数模型进行了测试,以显示内存增长。不过,作者还是建议大家在内存瓶颈得到修复之前使用 num_generations=4,也能获得不错的性能。



影响 VRAM 的一些因素

要对所有影响显存(VRAM)使用的因素进行全面的超参数验证,需要进行大量的实验。简单起见,这里只指出了需要注意的设置,以及实验中使用的具体数值。

  • batch_size=1,由于 GRPO 为每个查询生成多个响应,batch size 会迅速失控。
  • gradient_accumulation_steps=4,优化器是另一个占用大量 VRAM 的地方。此参数决定了我们将存储的梯度以帮助优化器进行其「爬山」过程。
  • num_completions=4,DeepSeekMath 论文中使用了 64。这完全超出了有些人的计算预算。
  • max_prompt_length=256,如果你想训练模型拥有更大上下文的推理能力,将不得不增加 VRAM。GSM8K 的提示相对较小,适合此测试。
  • max_completion_length=786,同样,由于计算注意力的内存有限,推理链在这里受到限制。上下文或生成的 token 越多,需要的内存就越大。
  • LoRA target_modules=["q_proj", "k_proj", "o_proj", "up_proj", "down_proj"] 在这方面可以尝试几种不同的迭代。target_modules="all-linear" 是一种流行的方式,可以从你的 LoRA 中挤出最多的性能(就准确性而言)。

对 VRAM 使用的粗略估算

如果你正在使用 FP16 精度进行训练,以下是一些简单的估算方法,可以帮助你了解内存主要用在了哪些地方:

  • 模型参数:每个参数占用 2 字节。
  • 参考模型参数:每个参数占用 2 字节。
  • 梯度:每个参数占用 2 字节。
  • 优化器状态:每个参数占用 8 字节。
  • 8 位优化器:每个参数占用 4 字节。
  • PEFT:有助于减少梯度的显存占用。

最后是关于准确率的。作者完成了一个 10 亿参数的 Llama 3.2 模型的完整训练。在应用 GRPO 之前,该模型在保留测试集上达到了约 19% 的准确率,而在经过一个训练周期后,模型的准确率飙升至约 40.5%。虽然这离 SOTA 水平还差得很远,但这展示了 GRPO 的强大潜力。

ad1 webp
ad2 webp
ad1 webp
ad2 webp