TL:DR

デフォルトで用意されているパラメータbatch_firstを適切に設定しないとleakageしまくるという話

背景

  • ランキングモデルにselfattentionを組み込もうとした時に,急にMRR=83%とか出てくる
  • 評価時の関数に問題があるかどうか確認したが,自分では問題を発見できなかった
  • attentionを組み込んだ際のみにこの現象が確認されたので,attention部分の自分の実装に問題がある

問題点

pytorchのデフォルトのMultiheadAttentionでは,入力として[L, N, D]形式のテンソルが想定されている

引数batch_firstをつけないと,系列長部分をbatchとして解釈してしまうため,バッチ内のleakageが発生する.(内部計算で系列長方向に参照しまくるため,バッチ内の他データを参照してしまう)

参考文献

公式リファレンス