新智元报道

编辑:LRST

【新智元导读】最新综述论文探讨了知识蒸馏在持续学习中的应用,重点研究如何通过模仿旧模型的输出来减缓灾难性遗忘问题。通过在多个数据集上的实验,验证了知识蒸馏在巩固记忆方面的有效性,并指出结合数据回放和使用separated softmax损失函数可进一步提升其效果。

知识蒸馏(Knowledge Distillation, KD)已逐渐成为持续学习(Continual Learning, CL)应对灾难性遗忘的常见方法。

然而,尽管KD在减轻遗忘方面取得了一定成果,关于KD在持续学习中的应用及其有效性仍然缺乏深入的探索。


图1 知识蒸馏在持续学习中的使用

目前,大多数现有的持续学习综述主要从不同方法的分类角度出发,聚焦于图像分类领域或其他应用领域,很少有综述文章专门探讨如何通过具体技术(如知识蒸馏)来缓解持续学习中的遗忘问题。

现有的研究大多关注于持续学习方法的广泛分类(如回放方法、正则化方法、参数隔离方法等),以及任务增量学习、类别增量学习、领域增量学习等不同场景的定义。

尽管这些研究为持续学习领域提供了宝贵的见解,但对于如何将知识蒸馏与持续学习结合并分析其效果,仍然缺乏系统性评估。

最近,哈尔滨工业大学和中科院自动化所的研究人员在IEEE Transactions on Neural Networks and Learning Systems(TNNLS)上发表了一篇综述论文,聚焦于知识蒸馏在持续学习中的应用。


论文链接:https://ieeexplore.ieee.org/document/10721446

主要贡献为:

  • 综合调查:首次系统地综述了基于知识蒸馏的持续学习方法,主要集中在图像分类任务中。研究人员分析了知识蒸馏在持续学习中的应用,提供了详细的分类,阐述了其在持续学习中的作用与应用场景。

  • 新的分类法:提出了一个新的分类体系,将知识蒸馏在持续学习中的应用分为三个主要范式:正则化的知识蒸馏、知识蒸馏与数据回放结合、以及知识蒸馏与特征回放结合。同时,基于蒸馏过程中使用的知识来源,将这些方法分为三个层次:logits级别、特征级别和数据级别,并从蒸馏损失的角度分析了其如何强化记忆。

  • 实验验证:在CIFAR-100、TinyImageNet和ImageNet-100等数据集上,针对十种基于知识蒸馏的持续学习方法进行了广泛的实验,系统地分析了知识蒸馏在持续学习中的作用,验证了其在减缓遗忘方面的有效性。

  • 分类偏差与改进:进一步证实,分类偏差可能会削弱知识蒸馏的效果,而采用separated softmax损失函数结合数据回放时,能够显著增强知识蒸馏在减缓遗忘方面的效果。

基于知识蒸馏的持续学习范式


图2 知识蒸馏在持续学习中的使用范式

正则化的知识蒸馏

正则化的知识蒸馏范式将知识蒸馏作为防止遗忘的核心机制,通过约束模型参数变化来保持旧任务的知识。这些方法的基本思想是通过在新任务学习时,确保模型的输出尽可能与旧任务模型的输出一致,从而避免遗忘。

例如,LwF方法通过蒸馏新任务数据在旧模型上的响应,确保新模型在学习新任务时仍能保留对旧任务的记忆[1]。这类方法的理念相对简单明了,在减轻遗忘方面的表现往往较弱,通常会导致较低的性能。

知识蒸馏与数据回放结合

知识蒸馏经常与数据回放技术相结合,以从数据和模型两个方面增强记忆保持能力。数据回放方法需要一个额外的缓存区来存储来自先前任务的样本,以近似其分布,并在持续学习过程中不断回放这些样本,以实现持久的记忆保持。

将知识蒸馏与数据回放结合,进一步增强了模型的记忆保持能力。iCaRL是第一个将知识蒸馏与数据回放相结合的方法[2]。此后,许多结合知识蒸馏和数据回放的方法将数据回放作为应对遗忘的基本技术,并探索各种蒸馏技术以进一步增强旧任务的记忆保持能力。

除了直接使用额外的内存存储来自旧任务的回放数据外,一些方法还通过生成模型[5][6]或模型反演[7][8][9]技术生成回放数据。

这些方法通常将知识蒸馏应用于生成的数据,以防止生成模型在持续学习过程中遗忘,同时也在logits或特征上使用基本的知识蒸馏技术来减缓遗忘。对于这些方法来说,除了知识蒸馏是减缓遗忘的有效手段外,生成数据的质量也在决定整体效果方面起着至关重要的作用。

由于回放数据与新任务数据之间存在严重的数据不平衡,这容易导致分类偏差,一些方法在将知识蒸馏作为记忆保持的基本机制的同时,着重解决分类偏差问题。

例如,BiC显式地通过在平衡的验证数据集上训练类别校正参数来解决分类偏差问题[10]。除了分类偏差问题外,一些其他方法将知识蒸馏与数据回放的结合作为基本的记忆保持手段,并更加关注其他问题,如回放数据的选择[13][14]以及结合基于架构的方法[15][16][17]来保持记忆。

知识蒸馏与特征回放结合

除了将知识蒸馏与数据回放结合,许多方法还将知识蒸馏与特征回放结合,旨在实现无需示例的持续学习。这一范式中的大多数方法通过在特征级别的蒸馏中使用实例特征对齐,以保持特征网络的记忆能力,并采用各种特征生成方法来生成回放特征,从而确保分类器的记忆得到保持。

例如,GFR方法通过训练生成模型来存储旧任务的特征,该生成模型在持续学习过程中生成回放特征[18]。PASS方法将类别原型定义为特征空间数据的均值,并在新类别学习期间引入高斯噪声进行数据增强,从而避免分类偏向新数据[19]。

与「知识蒸馏与数据回放结合」范式中的方法相比,这一范式不需要大量额外的内存来存储旧任务的原始样本。相反,它只需少量内存来存储每个类别的特征信息。此外,特征回放有助于减少由于回放数据和新任务数据之间的不平衡所引起的分类偏差问题。

知识来源与蒸馏损失


图3 按知识来源分类的基于知识蒸馏的持续学习方法

研究人员根据知识来源将基于知识蒸馏的持续学习方法分为三类:logits级别、特征级别和数据级别。

Logits级别蒸馏主要涉及学生模型通过模仿教师模型的最终输出logits来获取知识。这些输出通常包括两种类型:通过归一化函数(如softmax)得到的分类概率,以及原始的、未经归一化的logits。

因此,研究人员将logits级别的KD方法分为两类:概率匹配和logits匹配。概率匹配较为常见,学生模型旨在通过使用KL散度或交叉熵等损失函数,将教师模型的输出概率分布与自己的输出概率分布对齐。

相比之下,logits匹配旨在同步教师和学生模型的pre-softmax logit值,通常采用L1或L2范数等损失函数。logits匹配对蒸馏过程施加了比概率匹配更严格的约束。

特征级别蒸馏旨在传递网络特征提取阶段生成的内部表示知识。这类方法可以根据特征在网络中的位置和特征的性质分为三个子类:实例特征对齐、隐层特征对齐和关系对齐。

实例特征对齐主要针对从输入样本中提取的特征,这些特征通常被转换为一维向量。隐层特征对齐则关注特征提取器中间层特征的蒸馏,这些特征保留了与网络结构相关的空间信息。关系对齐则专注于蒸馏多个实例或原型特征之间在特征空间中的局部或全局关系动态。

数据级别蒸馏可以分为两种类型:显式数据对齐和隐式数据对齐。显式数据对齐涉及通过生成模型产生的合成数据进行蒸馏。与此不同,隐式数据对齐则专注于蒸馏数据中的潜在信息,例如注意力图或潜在编码。

图4显示了一些logits级别和特征级别蒸馏的示意图。表1展示了不同范式的持续学习方法使用的知识蒸馏级别以及相应使用的蒸馏损失。


图4 logits级别与特征级别蒸馏示意图


表1 基于知识蒸馏的持续学习方法归纳分类

实验

研究人员选择了三个在持续学习领域广泛使用的图像分类数据集:CIFAR-100、TinyImageNet和ImageNet-100,涵盖了从32×32、64×64到224×224像素的不同图像分辨率,实验聚焦于类别增量学习(CIL)。

研究人员采用了两种主要策略来模拟数据增量场景:第一种方法将数据集均匀分成多个任务,每个任务包含相等数量的类别,进行持续学习;第二种方法先对一部分类别进行初步的基础训练,然后使用剩余类别进行持续学习。

为了清晰描述这些场景,研究人员采用了[22]中的符号表示,选择了十个基于知识蒸馏的持续学习方法进行实验:LwF [1]、LwM [3]、IL2A [20]、PASS [19]、PRAKA [21]、iCaRL [2]、EEIL [4]、BiC [10]、LUCIR [11]和SS-IL [12]。

针对数据集的实验


表2 针对不同数据集的实验结果

实验结果如表2所示。针对所有数据集,在没有基础训练的10任务场景中,BiC方法在所有数据集上表现最佳。在有基础训练的11任务场景中,PRAKA在所有数据集上表现突出。在有无基础训练的两种场景中,「知识蒸馏与数据回放结合」范式的方法普遍表现较好。

在没有基础训练的场景中,「知识蒸馏与特征回放结合」范式的方法略逊于数据回放范式。然而,在有基础训练的场景中,特征回放方法的表现显著提升,PRAKA在所有数据集上超过了数据回放范式的方法。

相比之下,「正则化的知识蒸馏」范式方法表现较差,且LwF和LwM在有基础训练的场景中表现低于没有基础训练的情况,其他方法通常在有基础训练的场景中表现更好。

针对知识蒸馏效果的实验


表3 针对知识蒸馏效果的实验结果

本实验通过去除知识蒸馏损失函数,探讨了知识蒸馏在持续学习抗遗忘中起的作用,实验结果如表3所示。所有方法中,除了LwM采用了两种蒸馏损失外,大多数方法都使用了单一的蒸馏损失。对于LwM,仅去除其注意力图蒸馏损失,保留了logits级蒸馏。

知识蒸馏在「正则化的知识蒸馏」以及「知识蒸馏与特征回放结合」范式的方法中起到了关键作用。在有无基础训练的场景中,去除知识蒸馏后,性能明显下降。

然而,在「知识蒸馏与数据回放结合」范式下的方法中,情况有所不同。结果显示,在有基础训练的场景中,知识蒸馏显著有助于减缓遗忘,一旦去除蒸馏,所有方法的性能均有所下降。

在没有基础训练的场景中,EEIL、BiC和SS-IL在去除KD后表现下降。相反,iCaRL和LUCIR的性能有所提升,iCaRL的提升尤为明显,LUCIR的提升较小。

针对蒸馏损失的实验


表4 针对蒸馏损失的实验结果

为了评估不同蒸馏损失在减缓遗忘方面的有效性,研究人员进行了独立的知识蒸馏损失评估,未使用任何其他防止遗忘的技术。

研究人员评估了交叉熵、KL散度、logits级的L2距离损失,以及基于L2距离和余弦相似度的实例特征对齐损失,实验结果如表4所示。

在持续学习过程中,分类头的训练采用了LwF中的方式,即只训练当前任务的分类头,而之前任务的分类头仅参与蒸馏,因为如果没有来自旧任务的数据使用全局分类损失,会导致严重的分类偏差问题,并显著降低知识蒸馏的效果。

结果表明不同知识蒸馏损失均有减缓遗忘的能力。其中,logits级的知识蒸馏损失在减缓遗忘方面表现明显优于特征级的知识蒸馏损失。在所有logits级知识蒸馏损失中,L2距离损失具有更强的约束能力,较KL散度表现更好,优于交叉熵蒸馏损失的抗遗忘效果。

对于特征级的知识蒸馏损失,包含更多语义信息的余弦相似度损失,在减缓遗忘方面优于L2距离损失。

针对知识蒸馏与数据回放的实验


表5 针对知识蒸馏与数据回放的实验结果

为了进一步了解知识蒸馏在与数据回放结合时的作用,并探索不同知识蒸馏损失的效果,研究人员将几种知识蒸馏损失与基本的回放范式进行比较,数据回放使用herding方式来缓存回放数据,每个类别保存20个样本。

实验结果(表5 -a)表明将知识蒸馏与数据回放结合时,logits级的知识蒸馏损失始终会导致性能下降,这一负面影响在没有基础训练的情况下尤为明显,logits级知识蒸馏会显著降低性能。

在基础训练的情况下,特征级知识蒸馏的正面效果稍微更明显,而余弦相似度损失在保持已学习特征方面表现优越。然而,在没有基础训练的情况下,余弦相似度损失在保持记忆方面的效果不如L2损失。

研究人员假设这种现象可能是由于分类头引入的分类偏差所致。为了验证这一假设,采用SS-IL的方法中使用Separated softmax损失来学习分类头,即使用回放数据共同训练所有旧任务的分类头,而新任务数据则专门用于训练新任务的分类头。实验结果(表5 -b)表明分类偏差确实会影响KD的效果。

令人惊讶的是,即使没有使用KD,使用Separated softmax的数据回放也比全局分类的回放表现更好。

未来展望

论文从三个不同的视角探讨了基于知识蒸馏的持续学习的未来发展趋势。

高质量知识的知识蒸馏:尽管知识蒸馏在减缓持续学习中的灾难性遗忘方面已经展现出潜力,但仍有较大的提升空间。有效的知识传递依赖于蒸馏知识的质量。高质量的知识传递对于提升持续学习中的知识蒸馏效果至关重要。

随着对知识质量的要求越来越高,如何更好地提取和传递高质量知识,将是未来持续学习研究中的一个重要方向。

针对特定任务的知识蒸馏:持续学习的研究已从最初专注于分类任务,扩展到包括其他多种任务,例如计算机视觉中的目标检测、语义分割,以及自然语言处理中的语言学习、机器翻译、意图识别和命名实体识别等。

这表明,知识蒸馏不仅能够应用于传统的分类任务,还需要针对具体任务进行定制化设计,以提高在不同应用场景中的表现。

更好的教师模型:近年来,基于预训练模型(PTM)和大型语言模型(LLM)的持续学习受到了越来越多的关注。知识蒸馏作为一种自然适用于减少遗忘的技术,对于PTM和LLM的持续学习尤为重要。

这是因为知识蒸馏遵循教师-学生框架,而PTM和LLM已经具备了丰富的知识,可以作为「具有丰富的经验教师」开始持续学习,从而更有效地指导学生模型的学习。未来,如何通过更强大的教师模型来优化知识蒸馏的效果,将是持续学习中值得深入研究的方向。

参考资料:

1.Z. Li and D. Hoiem, “Learning without forgetting,” IEEE Trans. Pattern Anal. Mach. Intell., vol. 40, no. 12, pp. 2935–2947, 2017.

2.S.-A. Rebuffi, A. Kolesnikov, G. Sperl, and C. H. Lampert, “icarl: Incremental classifier and representation learning,” in IEEE Conf. Comput. Vis. Pattern Recog., pp. 2001–2010, 2017.

3.P. Dhar, R. V. Singh, K.-C. Peng, Z. Wu, and R. Chellappa, “Learning without memorizing,” in IEEE Conf. Comput. Vis. Pattern Recog., pp. 5138–5146, 2019.

4.F. M. Castro, M. J. Marı́n-Jiménez, N. Guil, C. Schmid, and K. Alahari, “End-to-end incremental learning,” in Eur. Conf. Comput. Vis., pp. 233–248, 2018.

5.C. Wu, L. Herranz, X. Liu, J. Van De Weijer, B. Raducanu, et al., “Memory replay gans: Learning to generate new categories without forgetting,” in Adv. Neural Inform. Process. Syst., vol. 31, pp. 5966–5976, 2018.

6.W. Hu, Z. Lin, B. Liu, C. Tao, Z. T. Tao, D. Zhao, J. Ma, and R. Yan, “Overcoming catastrophic forgetting for continual learning via model adaptation,” in Int. Conf. Learn. Represent., 2019.

7.J. Smith, Y.-C. Hsu, J. Balloch, Y. Shen, H. Jin, and Z. Kira, “Always be dreaming: A new approach for data-free class-incremental learning,”in Int. Conf. Comput. Vis., pp. 9374–9384, 2021.

8.Q. Gao, C. Zhao, B. Ghanem, and J. Zhang, “R-dfcil: Relation-guided representation learning for data-free class incremental learning,” in Eur. Conf. Comput. Vis., pp. 423–439, Springer, 2022.

9.M. PourKeshavarzi, G. Zhao, and M. Sabokrou, “Looking back on learned experiences for class/task incremental learning,” in Int. Conf. Learn. Represent., 2021.

10.Y. Wu, Y. Chen, L. Wang, Y. Ye, Z. Liu, Y. Guo, and Y. Fu, “Large scale incremental learning,” in IEEE Conf. Comput. Vis. Pattern Recog., pp. 374–382, 2019.

11.S. Hou, X. Pan, C. C. Loy, Z. Wang, and D. Lin, “Learning a unified classifier incrementally via rebalancing,” in IEEE Conf. Comput. Vis. Pattern Recog., pp. 831–839, 2019.

12.H. Ahn, J. Kwak, S. Lim, H. Bang, H. Kim, and T. Moon, “Ss-il: Separated softmax for incremental learning,” in Int. Conf. Comput. Vis., pp. 844–853, 2021.

13.Y. Liu, Y. Su, A.-A. Liu, B. Schiele, and Q. Sun, “Mnemonics training: Multi-class incremental learning without forgetting,” in IEEE Conf. Comput. Vis. Pattern Recog., pp. 12245–12254, 2020.

14.R. Tiwari, K. Killamsetty, R. Iyer, and P. Shenoy, “Gcr: Gradient coreset based replay buffer selection for continual learning,” in IEEE Conf. Comput. Vis. Pattern Recog., pp. 99–108, 2022.

15.J. Rajasegaran, M. Hayat, S. H. Khan, F. S. Khan, and L. Shao,“Random path selection for continual learning,” in Adv. Neural Inform. Process. Syst., vol. 32, pp. 12648–12658, 2019.

16.A. Douillard, A. Ramé, G. Couairon, and M. Cord, “Dytox: Transformers for continual learning with dynamic token expansion,” in IEEE Conf. Comput. Vis. Pattern Recog., pp. 9285–9295, 2022.

17.F.-Y. Wang, D.-W. Zhou, H.-J. Ye, and D.-C. Zhan, “Foster: Feature boosting and compression for class-incremental learning,” in Eur. Conf. Comput. Vis., pp. 398–414, Springer, 2022.

18.X. Liu, C. Wu, M. Menta, L. Herranz, B. Raducanu, A. D. Bagdanov, S. Jui, and J. v. de Weijer, “Generative feature replay for class-incremental learning,” in IEEE Conf. Comput. Vis. Pattern Recog., pp. 226–227, 2020.

19.F. Zhu, X.-Y. Zhang, C. Wang, F. Yin, and C.-L. Liu, “Prototype augmentation and self-supervision for incremental learning,” in IEEE Conf. Comput. Vis. Pattern Recog., pp. 5871–5880, 2021.

20.F. Zhu, Z. Cheng, X.-Y. Zhang, and C.-l. Liu, “Class-incremental learning via dual augmentation,” in Adv. Neural Inform. Process. Syst., vol. 34, pp. 14306–14318, 2021.

21.W. Shi and M. Ye, “Prototype reminiscence and augmented asymmetric knowledge aggregation for non-exemplar class-incremental learning,”in Int. Conf. Comput. Vis., pp. 1772–1781, 2023.

22.M. Masana, X. Liu, B. Twardowski, M. Menta, A. D. Bagdanov, and J. Van De Weijer, “Class-incremental learning: survey and performance evaluation on image classification,” IEEE Trans. Pattern Anal. Mach. Intell., vol. 45, no. 5, pp. 5513–5533, 2022.

ad1 webp
ad2 webp
ad1 webp
ad2 webp