最近看了几篇 Semi-Supervised Learning 的文章,感觉要达到 state-of-the-art 的话就是要把几个比较有效的技巧比较好地结合到一起,所以写一篇文章分别介绍一下这些技巧。文章总体的结构内容基于 MixMatch: A Holistic Approach to Semi-Supervised Learning (NerIPS 2019),及一些适当的延伸。

我们主要讨论就是 Transductive Learning 的场景,即 在训练时 \((X_{label}, Y_{label}, X_{unlabel})\),来预测 \(Y_{label}\) 。Inductive Learning 的场景就更像是 supervised learning,即在训练时只用\((X_{label}, Y_{label})\) ,在预测的时候才用到\(X_{unlabel}\) 来预测 \(Y_{unlabel}\)。半监督学习 的主要想法就是,认为需要预测的内容的特征(例如 分布信息)可以帮助学到一个更好的 classifier。

Entropy Minimization

在 Semi-Supervised Learning 中,一个常见的 Assumption 是,分类器的 decision boundary 不应该穿过数据分布中的 high-density 区域。一种直接的方式就是增加一个 loss term 来直接显式地降低 entropy。

Sharpening

对于unlabeled的数据,可以用 Sharpening 的技巧来隐式地降低Entropy。 \[ \text {Sharpen}\left(p_{i}, T\right):=p_{i}^{\frac{1}{T}} / \sum_{j=1}^{C} p_{j}^{\frac{1}{T}} \] 其中,\(p\) 是 categorical distribution,\(T\) 是一个 hyperparameter。当\(T \rightarrow 0\) 时,\(\text {Sharpen}\left(p_{i}, T\right)\) 就会接近于 Dirac Distribution (one-hot)。

由于 \(\text {Sharpen}\left(p_{i}, T\right)\) 会被用作模型对于 unlabeled 数据的预测的 target,所以较低的 \(T\) 会使得模型更倾向于输出低 entropy 的结果。

Consistency Regularization

简单来说就是把 data augmentation 的技巧运用到 Semi-Supervised Learning 中来。基于的基本想法就是对于一个sample,即使它被增强过,分类器输出的 class distribution 也应该是不变的。

Regularization with stochastic perturbations

这是最简单直接的方式,增加一个loss来控制不同的 stochastic transformations 对输出带来的影响: \[ \| p_{\text {model }}(y | \text { Augment }(x) ; \theta)-p_{\text {model }}(y | \text { Augment }(x) ; \theta) \|_{2}^{2} \]

Label Guessing

对于一个 unlabeled 的 sample,可以先用模型输出来为这个 sample 猜出一个 label 出来。然后这个猜出来的 label 可以用在 unsupervised loss 中(即 \(\mathcal{L}_{\mathcal{U}}=\frac{1}{L\left|\mathcal{U}^{\prime}\right|} \sum_{u, q \in \mathcal{U}^{\prime}}\|q - p_{model}(y | u ; \theta)\|_{2}^{2}\)\(u\) 是 unlabeled item,\(q\) 是 guessing label (distribution))。

实践中可以用 \(K\) 个 augmented 的 \(u_b\) 的模型预测的平均作为 guessing label 来增加稳定性: \[ \mathcal{L}_{\mathcal{U}}=\frac{1}{\left|\mathcal{U}^{\prime}\right|} \sum_{u, q \in \mathcal{U}^{\prime}}\|q-p_{model}(y | u ; \theta)\|_{2}^{2}. \]

也有文章有讨论了 mean-teacher (即 averaging model weights instead of predictions),认为这种方式会更好。

Exponential Moving Average (EMA)

即是对模型输出的 predictions 进行 EMA 的计算来作为新的 target。 \[ Z = \alpha Z + (1-\alpha)z \]

\[ \tilde{z} = \frac{Z}{1-\alpha^t} \]

\(Z\) 被初始为 \(\mathbf{0}_{N\times C}\)\(z\) 是每个 epoch 模型对于每个 sample 输出,\(t\) 是epoch,\(\tilde{z}\)是经过 bias correction 的 target vector。

Virtual Adversarial Training (VAT)

Virtual adversarial loss is defined as the robustness of the conditional label distribution around each input data point against local perturbation.

首先 adversarial training 就是在输入中加入一个扰动 \(r_{adv}\),使得模型的输出发生尽可能大程度的变化,即: \[ L_{\mathrm{adv}}\left(x_{l}, \theta\right):=D\left[q\left(y | x_{l}\right), p\left(y | x_{l}+r_{\mathrm{adv}}, \theta\right)\right] \] 其中 \[ r_{\mathrm{adv}}:=\underset{r ;\|r\| \leq \epsilon}{\arg \max } D\left[q\left(y | x_{l}\right), p\left(y | x_{l}+r, \theta\right)\right] \] \(D[q, p]\)是描述两个分布之间的 divergence 的非负函数,例如 Cross Entropy。

至于 Virtual Adversarial Training,就是用当前模型的输出 \(p(y|x, \hat\theta)\) 来近似数据label的真实概率分布\(q(y|x)\)。这样就定义了一种 virtual adversarial perturbation. 这样就很容易写出对应的 loss item: \[ \operatorname{LDS}\left(x_{*}, \theta\right):=D\left[p\left(y | x_{*}, \hat{\theta}\right), p\left(y | x_{*}+r_{\text {vadv }}, \theta\right)\right] \] 其中 \[ r_{\mathrm{vadv}}:=\underset{r ;\|r\|_{2} \leq \epsilon}{\arg \max } D\left[p\left(y | x_{*}, \hat{\theta}\right), p\left(y | x_{*}+r\right)\right] \] \(x_*\) 包含了 \(x_{label}, x_{unlabel}\).

Generic Regularization

有一些 Regularizaion 的方法就是给模型加上一些 constraint 使它避免“记住”训练数据,从而更好地 generalize 到别的 unseen data。最为常见的一种做法就是给模型参数加上一个 \(L_2\)-weight-decay。

Mixup

简单来说,mixup 就是构造这样的虚拟training samples: \[ \tilde{x} = \lambda x_i + (1-\lambda)x_j, \\ \tilde{y} = \lambda y_i + (1-\lambda)y_j. \\ \] 其中 \(x_i, x_j\) 是原始的输入vector,\(y_i, y_j\) 是 one-hot label encoding,\(\lambda \in [0, 1]\)

把 mixup 应用到 Semi-Supervised Learning 的话,可以把 labeled data 和 unlabeled data 一起 mixup 起来,其中 unlabeled sample 的 \(y\) 可以换成 guessing label \(q\)。另外还可以加一个小trick就是让\(\lambda \in [0.5, 1]\),使得虚拟的 sample 可以更靠近真实的数据。这种情况下mixup其实应该属于 Consistency Regularization.

Why \(L_2\) loss

用 cross entropy 的时候,需要先用 Softmax 计算出概率,但是如果所有输出值都加上一个常数的话,softmax 的结果是不变的。所以为了让两个向量尽可能相等,\(L_2\) 是更为严格的限制。

Warmup of \(\lambda\)

整体的 loss function 是由监督loss和unlabeled data的loss组合起来的 \(L = L_X + \lambda L_U\),所以中间会有一个\(\lambda\) 来控制两者的比例。相比于直接将\(\lambda\)设置为一个常数,一些实验中发现将它从0慢慢 linear warmup 到它的 final value 可以提升最后的分类 accuracy。

References

  1. MixMatch: A Holistic Approach to Semi-Supervised Learning
  2. Semi-supervised Learning by Entropy Minimization
  3. Temporal Ensembling for Semi-Supervised Learning
  4. Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results
  5. mixup: Beyond Empirical Risk Minimization