论文解读(PairNorm)《PairNorm: Tackling Oversmoothing in GNNs》

论文信息

论文标题:PairNorm: Tackling Oversmoothing in GNNs
论文作者:Lingxiao Zhao, Leman Akoglu
论文来源:2020,ICLR
论文地址:download 
论文代码:download 

1 Introduction

  GNNs 的表现随着层数的增加而有所下降,一定程度上归结于 over-smoothing 问题,重复图卷积操作会使得节点表示最终变得不可区分。为缓解过平滑问题提出了 PairNorm, 一种归一化方法。

  比较可惜的时,该论文在使用了 2022 年的 "Mask" 策略,可惜了实验做的不咋好。为什么失败,见文末。太可惜了...

2 Understanding oversmoothing

Definition

    $tilde{mathbf{A}}_{mathrm{sym}}=tilde{mathbf{D}}^{-1 / 2} tilde{mathbf{A}} tilde{mathbf{D}}^{-1 / 2}$

    $tilde{mathbf{A}}_{mathrm{rw}}=tilde{mathbf{D}}^{-1} tilde{mathbf{A}}$

2.1 The oversmoothing problem

2.1.1 Oversmoothing

  GNN 性能下降的原因:

    • 参数数量的增加;
    • 梯度消失导致训练困难;
    • 图卷积而造成的过平滑;

  过平滑的考虑方法如下:当多次使用拉普拉斯平滑导致节点特征收敛到一个平稳点。假设 $mathbf{x}_{cdot j} in mathbb{R}^{n}$ 表示 $mathbf{X}$ 的第 $j $ 列,对于任意 $mathbf{x}_{cdot j} in mathbb{R}^{n}$:

    $begin{array}{l}underset{k rightarrow infty}{text{lim}} quad  tilde{mathbf{A}}_{mathrm{sym}}^{k} mathbf{x}_{cdot j} =boldsymbol{pi}_{j}\ text { and } quad frac{boldsymbol{pi}_{j}}{left|boldsymbol{pi}_{j}right|_{1}}=boldsymbol{pi}end{array}$

  其中,标准化解 $pi in mathbb{R}^{n}$ 满足 $boldsymbol{pi}_{i}=frac{sqrt{operatorname{deg}_{i}}}{sum_{i} sqrt{operatorname{deg}_{i}}}  text{ for all }  i in[n]$。

  Note:$boldsymbol{pi}$ 不依赖于节点特征矩阵,而是一个单纯依靠图结构度的函数。

2.1.2 Its Measurement

  本文提出两种度量过平滑的方式:$text{row-diff}$ 和  $text{col-diff}$。

  设 $mathbf{H}^{(k)} in mathbb{R}^{n times d}$ 为第 $k$ 个图卷积后的节点表示矩阵,即 $mathbf{H}^{(k)}=tilde{mathbf{A}}_{mathrm{sym}}^{k} mathbf{X}$。设 $mathbf{h}_{i}^{(k)} in mathbb{R}^{d}$ 为 $mathbf{H}^{(k)}$ 的第 $i$ 行,$mathbf{h}_{. i}^{(k)} in mathbb{R}^{n}$ 为 $mathbf{H}^{(k)}$ 的第 $i$ 列。

  $text{row-diff}(  left.mathbf{H}^{(k)}right)$ 和 $text{col-diff}(  left.mathbf{H}^{(k)}right)$ 的定义如下:

    ${large operatorname{row}-operatorname{diff}left(mathbf{H}^{(k)}right) =frac{1}{n^{2}} sumlimits _{i, j in[n]}left|mathbf{h}_{i}^{(k)}-mathbf{h}_{j}^{(k)}right|_{2}}  quadquadquad(2)$

    ${large operatorname{col-diff}left(mathbf{H}^{(k)}right) =frac{1}{d^{2}} sumlimits _{i, j in[d]}|    frac{mathbf{h}_{cdot i}^{(k)}}{|mathbf{h}_{cdot i}^{(k)}|_{1}}-frac{mathbf{h}_{cdot j}^{(k)}}{|mathbf{h}_{cdot j}^{(k)}|_{1}}   |_{2}} quadquadquad(3) $

  $text{row-diff}$ 量化节点之间的成对距离,而 $text{col-diff}$ 特征之间的成对距离。

2.2 Studying oversmoothing with SGC

  GCN 过平滑可能由于层数增加导致的性能下降,即添加更多的层导致更多的参数(添加的线性层 存在 $mathbf{W}^{(k)}$)容易导致过拟合。同样层数增加,容易存在反向传播梯度的消失(应该指的是参数多)。

  将层数增加影响过平滑和 使用参数导致过拟合即反向传播梯度消失 解耦。本文使用 SGC ,一种简化的 GCN :去除图卷积层的所有投影参数和所有层间的非线性激活。SGC可写为:

    $widehat{boldsymbol{Y}}=operatorname{softmax}left(tilde{mathbf{A}}_{mathrm{sym}}^{K} mathbf{X} mathbf{W}right) quadquadquad(4) $

  其中,$K$ 为图卷积的个数,$mathbf{W} in mathbb{R}^{d times c}$ 表示可学习参数。

  Note:SGC有一个固定数量的参数,不依赖于图卷积的数量(即层),也因此防止了过拟合和消失梯度问题的影响。

  那么,这只给我们留下了过平滑作为随着 $K$ 增加的性能下降的可能原因。需要注意的是 SGC 并不是一种牺牲,在某些分类任务似乎有更好或者相似的准确性。

  Figure 1 中的虚线说明了当增加层数( $K$ )时,SGC 在 Cora 数据集上的性能。训练(交叉熵)损失随着 $K$ 的增大而单调地增加,这可能是因为图卷积将节点表示与它们的邻居混合在一起,使它们变得不那么容易区分(训练变得更加困难)。另一方面,至多到 $K=4$,图卷积(即平滑)提高了泛化能力,减少了训练和验证/测试损失之间的差距,之后,过平滑开始影响性能。$text{row-diff}$ 和 $text{col-diff}$ 都随 $K$ 继续单调递减,为过平滑提供了支持证据。

  论文解读(PairNorm)《PairNorm: Tackling Oversmoothing in GNNs》插图

3 Tackling oversmoothing

3.1 Proposed pairnorm

  考虑图正则化最小二乘(GRLS):设 $overline{mathbf{X}} in mathbb{R}^{n times d}$ 是节点表示矩阵,其中 $overline{mathbf{x}}_{i} in mathbb{R}^{d}$ 表示 $overline{mathbf{X}}$ 的第 $i$ 行,GRLS 问题为:

    $underset{overline{mathbf{x}}}{text{min}} sumlimits _{i in mathcal{V}}left|overline{mathbf{x}}_{i}-mathbf{x}_{i}right|_{tilde{mathbf{D}}}^{2}+sumlimits_{(i, j) in mathcal{E}}left|overline{mathbf{x}}_{i}-overline{mathbf{x}}_{j}right|_{2}^{2}quadquadquad(5)$

  其中:

    • $left|mathbf{z}_{i}right|_{tilde{mathbf{D}}}^{2}=mathbf{z}_{i}^{T} tilde{mathbf{D}} mathbf{z}_{i}$;

  第一项可以看作是度加权最小二乘,第二个是一个图正则化项,度量新特征在图结构上的变化。

  优化问题的目标可认为是估计新的 “去噪” 特征 $overline{mathbf{x}}_{i}$ 离输入特征 $mathbf{x}_{i}$ 不远,并且在图结构上很平滑。

  GRLS 问题有一个封闭形式的解 $overline{mathbf{X}}=left(2 mathbf{I}-tilde{mathbf{A}}_{mathrm{rw}}right)^{-1} mathbf{X}$,其中 $tilde{mathbf{A}}_{mathrm{rw}} mathbf{X}$ 是一阶泰勒近似,即 $tilde{mathbf{A}}_{mathrm{rw}} mathbf{X} approx overline{mathbf{X}}$。通过替换 $tilde{mathbf{A}}_{mathrm{rw}}$ 为 $tilde{mathbf{A}}_{text {sym }}$,得到与图卷积相同的形式,即 $tilde{mathbf{X}}=tilde{mathbf{A}}_{text {sym }} mathbf{X} approx overline{mathbf{X}}$。因此,图卷积可以看作是 $text{Eq.5}$ 的近似解,它最小化了图结构上的变化,同时保持新的表示接近原始表示。

  理想情况下,希望获得对同一集群内的节点的平滑,但是避免平滑来自不同集群的节点。$text{Eq.5}$ 中的目标通过图正则化项只优化第一个目标。因此,当重复应用卷积时,它容易出现过平滑。为规避这个问题并同时实现这两个目标,可以添加一个负项,如没有边连接对之间的距离之和如下:

    $underset{overline{mathbf{x}}}{text{min}}  sumlimits _{i in mathcal{V}}left|overline{mathbf{x}}_{i}-mathbf{x}_{i}right|_{tilde{mathbf{D}}}^{2}+sumlimits_{(i, j) in mathcal{E}}left|overline{mathbf{x}}_{i}-overline{mathbf{x}}_{j}right|_{2}^{2}-lambda sum_{(i, j) notin mathcal{E}}left|overline{mathbf{x}}_{i}-overline{mathbf{x}}_{j}right|_{2}^{2}quadquadquad(6)$

  同样,可通过推导 $text{Eq.6}$ 的封闭型解并用一阶泰勒展开进行逼近,得到一个具有超参数 $lambda$ 的修正图卷积算子。

  在本文中,没有提出了一个全新的图卷积算子,而是提出了一个通用的、有效的 “补丁”,称为 PAIRNORM,它可以应用于具有过平滑潜力的任何形式的图卷积。

  设 $tilde{mathbf{X}}$(图卷积的输出)和 $dot{mathbf{X}}$ 分别为 PAIRNORM 的输入和输出。观察到图卷积 $tilde{mathbf{X}}=tilde{mathbf{A}}_{text {sym }} mathbf{X}$ 的输出实现了第一个目标 度加权,PAIRNORM 作为一个标准化层,在 $tilde{mathbf{X}}$ 上工作,以实现第二个目标,即保持未连接的对表示更远。具体来说,PAIRNORM 将 $tilde{mathbf{X}}$ 归一化,使总成对平方距离 $operatorname{TPSD}(dot{mathbf{X}}):=sumlimits_{i, j in[n]}left|dot{mathbf{x}}_{i}-dot{mathbf{x}}_{j}right|_{2}^{2} $ 和 $operatorname{TPSD}(mathbf{X} )$ 一样:

    $ sumlimits_{(i, j) in mathcal{E}}left|dot{mathbf{x}}_{i}-dot{mathbf{x}}_{j}right|_{2}^{2}+sumlimits_{(i, j) notin mathcal{E}}left|dot{mathbf{x}}_{i}-dot{mathbf{x}}_{j}right|_{2}^{2}=sumlimits_{(i, j) in mathcal{E}}left|mathbf{x}_{i}-mathbf{x}_{j}right|_{2}^{2}+sumlimits_{(i, j) notin mathcal{E}}left|mathbf{x}_{i}-mathbf{x}_{j}right|_{2}^{2} quadquadquad(7)$
  理想情况下,希望 $sumlimits _{(i, j) notin mathcal{E}}left|dot{mathbf{x}}_{i}-dot{mathbf{x}}_{j}right|_{2}^{2}$ 和 $sumlimits _{(i, j) notin mathcal{E}}left|mathbf{x}_{i}-mathbf{x}_{j}right|_{2}^{2}$ 一样大,$sumlimits _{(i, j) in mathcal{E}}left|dot{mathbf{x}}_{i}-dot{mathbf{x}}_{j}right|_{2}^{2} approx sumlimits _{(i, j) in mathcal{E}}left|tilde{mathbf{x}}_{i}-tilde{mathbf{x}}_{j}right|_{2}^{2}$ 是由于拉普拉斯平滑的原因。

  实践中,不需要时刻关注 $operatorname{TPSD}(mathbf{X} )$ 的值,只需要在所有层使得 $operatorname{TPSD}(mathbf{X} )$ 保持一个恒定的常量 $C$。

  为计算 $operatorname{TPSD}(mathbf{X} )$ 的常数值,可先计算 $operatorname{TPSD}(tilde{mathbf{X}})$。当然直接计算 $operatorname{TPSD}(tilde{mathbf{X}})$ 涉及到 $n^{2}$ 个成对的距离 $mathcal{O}left(n^{2} dright)$,这对大数据集来说是十分耗时间的。

  同样地,规范化可以通过一个两步的方法来完成,其中  $operatorname{TPSD}$ 被重写为

    $operatorname{TPSD}(tilde{mathbf{X}})=sumlimits_{i, j in[n]}left|tilde{mathbf{x}}_{i}-tilde{mathbf{x}}_{j}right|_{2}^{2}=2 n^{2}left(frac{1}{n} sumlimits_{i=1}^{n}left|tilde{mathbf{x}}_{i}right|_{2}^{2}-left|frac{1}{n} sumlimits_{i=1}^{n} tilde{mathbf{x}}_{i}right|_{2}^{2}right) quadquadquad(8)$

  $text{Eq.8}$ 的第一项 表示节点表示的均方长度,第二项描述了节点表示的均值的平方长度。

  为简化 $text{Eq.8}$ 的计算,令每个 $tilde{mathbf{x}}_{i}$ 减去行均值 $tilde{mathbf{x}}_{i}^{c}=tilde{mathbf{x}}_{i}-frac{1}{n} sumlimits _{i}^{n} tilde{mathbf{x}}_{i}$,其中 $tilde{mathbf{x}}_{i}^{c}$ 表示中心表示。这种移动不会影响 $operatorname{TPSD}$,并且驱动了项 $left|frac{1}{n} sumlimits _{i=1}^{n} tilde{mathbf{x}}_{i}right|_{2}^{2} $ 趋近 $0$。那么,计算 $operatorname{TPSD}(tilde{mathbf{X}}) $ 可归结为计算 $tilde{mathbf{X}}^{c}$ 的 $F$ 范数的平方,并有 $mathcal{O}(n d)$:

    $operatorname{TPSD}(tilde{mathbf{X}})=operatorname{TPSD}left(tilde{mathbf{X}}^{c}right)=2 nleft|tilde{mathbf{X}}^{c}right|_{F}^{2} quadquadquad(9)$

 $text{Eq.9}$ 可以写成一个两步的、中心和规模的归一化过程:

    $tilde{mathbf{x}}_{i}^{c}=tilde{mathbf{x}}_{i}-frac{1}{n} sumlimits _{i=1}^{n} tilde{mathbf{x}}_{i}  quadquadtext{(Center)}quad(10)$

    $dot{mathbf{x}}_{i}=s cdot frac{tilde{mathbf{x}}_{i}^{c}}{sqrt{frac{1}{n} sumlimits_{i=1}^{n}left|tilde{mathbf{x}}_{i}^{c}right|_{2}^{2}}}=s sqrt{n} cdot frac{tilde{mathbf{x}}_{i}^{c}}{sqrt{left|tilde{mathbf{X}}^{c}right|_{F}^{2}}} quadquadtext{(Scale)}quad(11)$

  缩放后,数据保持中心化 $left|sumlimits _{i=1}^{n} dot{mathbf{x}}_{i}right|_{2}^{2}=0$ 。在 $text{Eq.11}$ 中,$s$ 是一个超参数,它决定了 $C$。具体来说,

    $operatorname{TPSD}(dot{mathbf{X}})=2 n|dot{mathbf{X}}|_{F}^{2}=2 n sumlimits_{i}left|s cdot frac{tilde{mathbf{x}}_{i}^{c}}{sqrt{frac{1}{n} sumlimits_{i}left|tilde{mathbf{x}}_{i}^{c}right|_{2}^{2}}}right|_{2}^{2}=2 n frac{s^{2}}{frac{1}{n} sumlimits_{i}left|tilde{mathbf{x}}_{i}^{c}right|_{2}^{2}} sumlimits_{i}left|tilde{mathbf{x}}_{i}^{c}right|_{2}^{2}=2 n^{2} s^{2} quad(12)$

   然后,$dot{mathbf{X}}:=operatorname{PAIRNORM}(tilde{mathbf{X}})$ 拥有行均值为 $0$ (Center),和恒定的总成对平方距离 $C=2 n^{2} s^{2}$。在 Figure 2 中给出了一对范数的说明。PAIRNORM 的输出被输入到下一个卷积层。

  论文解读(PairNorm)《PairNorm: Tackling Oversmoothing in GNNs》插图1

  本文还推导出 PAIRNORM 的变体,即通过替换 $text{Eq.11}$ 的 $sumlimits _{i=1}^{n}left|tilde{mathbf{x}}_{i}^{c}right|_{2}^{2} $ 为 $nleft|tilde{mathbf{x}}_{i}^{c}right|_{2}^{2}$ ,本文称之为 PAIRNORM-SI ,此时所有的节点都有相同的 $L_{2}$ 范数 $s$ 。

  在实践中,发现 PAIRNORM 和 PAIRNORM-SI 对 SGC 都很有效,而 PAIRNORM-SI 对 GCN 和 GAT 提供了更好和更稳定的结果。GCN 和 GAT 需要更严格的归一化的原因可能是因为它们有更多的参数,更容易发生过拟合。在所有实验中,对SGC采用PAIRNORM,对 GCN 和 GAT 采用 PAIRNORM-SI。

  Figure 1 中的实线显示了 SGC 性能, 与 “vanilla” 版本相比,随着层数的增加,我们在每个图卷积层之后使用 PAIRNORM。类似地,Figure 3 用于 GCN 和 GAT(在每个图卷积激活后应用PAIRNORM-SI)。请注意,PAIRNORM 的性能衰减要慢得多。

  论文解读(PairNorm)《PairNorm: Tackling Oversmoothing in GNNs》插图2

  虽然 PAIRNORM 使更深层次的模型对过度平滑更稳健,但总体测试精度没有提高似乎很奇怪。事实上,文献中经常使用的基准图数据集需要不超过 $4$ 层,之后性能就会下降(即使是缓慢的)。

3.2 A case where deeper GNNs are beneficial

  如果一个任务需要大量的层来实现其最佳性能,那么它将更多的收益于使用 PAIRNORM,为此本文研究了 “missing feature setting”,即节点的一个子集存在特征缺失。

  假设 $mathcal{M} subseteq mathcal{V}_{u}$ 代表特征缺失子集,其中 $forall m in mathcal{M}$,$mathbf{x}_{m}=emptyset $。本文设置 $p=|mathcal{M}| /left|mathcal{V}_{u}right|$ 代表缺失比例。将这种任务的变体称为具有缺失向量的半监督节点分类(SSNC-MV)。直观的说,需要更多的传播步骤才能恢复这些节点有效的特征表示。

  Figure 4 显示了随着层数的增加,SGC、GCN 和 GAT 模型在 Cora 上的性能变化,其中我们从所有未标记的节点中删除特征向量,即 $p=1$。与没有PAIRNORM 的模型相比,具有 PAIRNORM 的模型获得了更高的测试精度,它们通常会达到更多的层数。

  论文解读(PairNorm)《PairNorm: Tackling Oversmoothing in GNNs》插图3

4 Experiments

  在本节中,我们设计了广泛的实验来评估在SSNC-MV设置下的SGC、GCN和GAT模型的有效性。

4.1 Experiment setup

  论文解读(PairNorm)《PairNorm: Tackling Oversmoothing in GNNs》插图4

4.2 Experiment results

核心代码:

if __name__ == "__main__":
    mode = 'PN'
    scale = 1
    x  =torch.randint(0,10,(3,2)).type(torch.float)
    col_mean = x.mean(dim=0)
    if mode == 'PN':
        x = x - col_mean
        print("x = ",x)
        rownorm_mean = (1e-6 + x.pow(2).sum(dim=1).mean()).sqrt()
        x = scale * x / rownorm_mean

    if mode == 'PN-SI':
        x = x - col_mean
        rownorm_individual = (1e-6 + x.pow(2).sum(dim=1, keepdim=True)).sqrt()
        x = scale * x / rownorm_individual

    if mode == 'PN-SCS':
        rownorm_individual = (1e-6 + x.pow(2).sum(dim=1, keepdim=True)).sqrt()
        x = scale * x / rownorm_individual - col_mean

节点分类

  论文解读(PairNorm)《PairNorm: Tackling Oversmoothing in GNNs》插图5

  论文解读(PairNorm)《PairNorm: Tackling Oversmoothing in GNNs》插图6

  论文解读(PairNorm)《PairNorm: Tackling Oversmoothing in GNNs》插图7

代码以 Deep_GCN 为例子:

class DeepGCN(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, nlayer=2, residual=0,
                 norm_mode='None', norm_scale=1, **kwargs):
        super(DeepGCN, self).__init__()
        assert nlayer >= 1 
        self.hidden_layers = nn.ModuleList([
            GraphConv(nfeat if i==0 else nhid, nhid)  for i in range(nlayer-1)
        ])
        self.out_layer = GraphConv(nfeat if nlayer==1 else nhid , nclass)

        self.dropout = nn.Dropout(p=dropout)
        self.dropout_rate = dropout
        self.relu = nn.ReLU(True)
        self.norm = PairNorm(norm_mode, norm_scale)
        self.skip = residual

    def forward(self, x, adj):
        x_old = 0
        for i, layer in enumerate(self.hidden_layers):
            x = self.dropout(x)
            x = layer(x, adj)
            x = self.norm(x)
            x = self.relu(x)
            if self.skip>0 and i%self.skip==0:
                x = x + x_old
                x_old = x
            
        x = self.dropout(x)
        x = self.out_layer(x, adj)
        return x

5 Conclusion

  提出了一种有效防止过平滑问题的 成对范数 ,一种新的归一化层,提高了深度 GNNs 对过平滑的鲁棒性。

6 Reason of failure

  即实验对于 mask feature 只处理了一次,并没有在每个 epoch 中进行处理。

论文解读(PairNorm)《PairNorm: Tackling Oversmoothing in GNNs》插图8论文解读(PairNorm)《PairNorm: Tackling Oversmoothing in GNNs》插图9

def load_data(data_name='Cora', normalize_feature=True, missing_rate=0, cuda=False):
    # can use other dataset, some doesn't have mask
    print(os.path.join(DATA_ROOT, data_name))
    dataset = geo_data.Planetoid(DATA_ROOT, data_name)
    print("dataset = ",dataset)
    # print(dataset[0])
    # print(dataset.data)
    data = geo_data.Planetoid(DATA_ROOT, data_name).data

    # original split
    data.train_mask = data.train_mask.type(torch.bool)
    data.val_mask = data.val_mask.type(torch.bool)
    # data.test_mask = data.test_mask.type(torch.bool)    
    # expand test_mask to all rest nodes 
    data.test_mask = ~(data.train_mask + data.val_mask)
    # get adjacency matrix
    n = len(data.x)
    adj = sp.csr_matrix((np.ones(data.edge_index.shape[1]), data.edge_index), shape=(n,n))
    adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj) + sp.eye(adj.shape[0])
    adj = normalize_adj_row(adj) # symmetric normalization works bad, but why? Test more. 
    data.adj = to_torch_sparse(adj)
    # normalize feature
    if normalize_feature:
        data.x = row_l1_normalize(data.x)
    
    # generate missing feature setting 
    indices_dir = os.path.join(DATA_ROOT, data_name, 'indices')
    if not os.path.isdir(indices_dir): 
        os.mkdir(indices_dir)
    missing_indices_file = os.path.join(indices_dir, "indices_missing_rate={}.npy".format(missing_rate))
    if not os.path.exists(missing_indices_file):
        erasing_pool = torch.arange(n)[~data.train_mask] # keep training set always full feature
        size = int(len(erasing_pool) * (missing_rate/100))
        idx_erased = np.random.choice(erasing_pool, size=size, replace=False)
        np.save(missing_indices_file, idx_erased)
    else:
        idx_erased = np.load(missing_indices_file)
    # erasing feature for random missing 
    if missing_rate > 0:
        data.x[idx_erased] = 0
    
    if cuda:
        data.x = data.x.cuda()
        data.y = data.y.cuda()
        data.adj = data.adj.cuda()
    
    return data   

View Code

文章来源于互联网:论文解读(PairNorm)《PairNorm: Tackling Oversmoothing in GNNs》

THE END
分享
二维码