TFX: A TensorFlow-Based Production-Scale Machine LearningPlatform 読んだメモ
Creating and maintaining a platform for reliably producingand deploying machine learning models requires careful or-chestration of many components—a learner for generatingmodels based on training data, modules for analyzing and val-idating both data as well as models, and finally infrastructurefor serving models in production. This becomes particularlychallenging when data changes over time and fresh modelsneed to be produced continuously. Unfortunately, such or-chestration is often done ad hoc using glue code and customscripts developed by individual teams for specific use cases,leading to duplicated effort and fragile systems with hightechnical debt.
↑を解決するGoogleの基盤TFX
TensorFlowベースの汎用機械学習プラットフォームであるTensorFlow Extended(TFX)
やったこと・改善点
- やったこと
- モデルを作るプラットフォームを作った(TFX)
- コンポーネントの標準化
- プラットフォーム構成の簡素化
- disruptionsを最小限に抑えるプラットフォームの安定性を実現
- 改善点
- 再利用性のない実装からの回避(技術的負債の軽減)
- 実験サイクルの高速化 数か月から数週間の生産時間の短縮
- データおよびモデル分析の改善によりアプリのインストールが2%増加
モデルを作るプラットフォームの複雑さとは
- Building one machine learning platform for many different learning tasks
- Continuous training and serving
- Human-in-the-loop
- UIを公開して,エンジニアが簡単にデプロイ・監視できるように
- データサイエンティストがデータとモデルを理解・分析するのを支援する必要性
- 新しいデータに対して合理的に動作するかどうかを予測することは困難である[8]
- Production-level reliability and scalability
- 一貫性のないデータ,ソフトウェア,ユーザーコンフィグによるdisruptions,および基礎となる実行環境の障害による失敗に対してreliable
- model validation, data validation:変な訓練データ,不良なモデルの検証
- 大量のデータ,サービスシステムへの運用トラフィックの増加にscalable
- 一貫性のないデータ,ソフトウェア,ユーザーコンフィグによるdisruptions,および基礎となる実行環境の障害による失敗に対してreliable
component詳細
data analysis (Sections 3.1)
特徴量に関する記述統計を出力する
データをセグメントを切って個別に見ることも可能
- バイナリ分類問題の正例,不例
- 特徴量ごとの相関,共分散
data transformation (Section 3.2)
特徴量のラングリングを可能にする
- vocabularies: feature-to-integer mappings
- など
訓練と予測で一貫性のある同じ変換ロジックを使わないとダメ
- エンコーダ系
モデルの一部として変換を出力することでこれを回避
data validation (Section3.3)
スキーマによるデータの異常検出
機能
- どんな異常が検出され,その範囲が一目でわかるようにする
- 各異常には,ユーザーがデータをデバッグおよび修正する方法を理解するのに役立つ簡単な説明が必要
- 特徴量の値が特定の範囲外であるという異常
- 予想される分布と実際の分布のKLの相違がしきい値を超えたという異常→これはデバックしづらい
- 適切にスキーマの変更を提案できるか(特徴料のドメインの変更)
- 新しいユニーク値の提案
- 範囲の増加
- 異常の文書化・追跡・報告がきちんとしたフォーマットで出力される
検出項目
trainer (Section 4)
すべてのトレーニングユースケースをサポートできるプロダクション品質モデルのトレーニングプロセスを合理化(および可能な限り自動化)する
warm-starting
これを活用してあまり多くのリソースを消費せずに高品質のモデルを実現
前提
実際には転移学習
the ability to selectively warm-start selected features of the network was identified asa crucial component and its implementation in TensorFlow was subsequently open sourced.
他にも,TensorFlowはさまざまな学習手法(ディープラーニング、ワイドアンドディープラーニング、シーケンスラーニングなど)を使用してモデルトレーニングを構成するための高レベルな統合APIを提供
High-Level Model Specification API
ベストプラクティスをエンコードする高レベルの抽象化レイヤーで実装の詳細を隠す
- FeatureColumns:ユーザーが機械学習モデルのどの機能に焦点を当てるのに役立つ
- Estimator:trainとevaluateが絶対にある
model evaluation and validation (Section 5)
- evaluation: ビジネスメトリックに対するオフライン評価
- データのセグメントを切って比較可能(country=USなど)
- validation: カナリアリリースによって品質をしきい値とベースラインモデルとの比較
いいモデルとは?
- モデルは安全に提供可能
- ロード時,または不良・予期しない入力がきたときに,クラッシュしたりエラーを引き起こさない
- ライブラリのバージョンなど
- あまり多くのリソース(CPUやRAMなど)を使用しない
- ロード時,または不良・予期しない入力がきたときに,クラッシュしたりエラーを引き起こさない
- モデルは一定以上の予測の品質を持つ
課題点:モデルの動作で予想される変化と予期しない変化を区別するのが難しい
メトリックの変化に保守的すぎると警告が増え,狼少年問題
→緩いしきい値に引っかかるものを警告すればバグは検知できるかも
ユーザはこの機能を欲しがってたわけじゃない(別にモデルの精度が上がるわけじゃないし,むしろ手間がかかる)
→基盤チームから出て来た機能
→実際に検証で防げたはずの問題に遭遇すると使われ始めた
serving (Section 6)
TensorFlow Servingの話
機械学習をサービングするためのカスタマイズ可能なフレームワークを提供することで,機械学習モデルのプロダクショングレードの提供システムを展開するために必要な定型コードの削減を目指す
プロダクションに乗せるために必要な要素 - lowlatency - high efficiency - horizontal scalability - reliability - robustness
Multitenancy with Isolation TF ServingにおいてのMultitenancy:複数の機械学習モデルを同時に提供可能にすること
model-isolation モデル分離
システムが新しいモデルをロードしているときに多数のクエリがきたら?
→モデルの読み込み操作用に分離された専用スレッドプールの構成を可能にする機能をTensorFlowに実装.呼び出し元が指定したスレッドプールで任意の操作を実行
今まで高負荷時にモデルロードが走った時のレスポンスタイムの99.9パーセンタイルのレイテンシは約500〜1500msecだったが,↑の機能により75~150msecに
Fast Training Data Deserialization ニューラルネットワーク以外のモデル(線形モデルなど)は、CPU集約型よりもデータ集約型 このようなモデルでは,データ入力・出力,前処理がボトルネックになる傾向
→汎用プロトコルバッファパーサーの使用は非効率的だった
→複数の解析構成でのさまざまな実際のデータ分布のプロファイルに基づいて,専用のプロトコルバッファパーサーを構築
2〜5倍の高速化に成功
Argo Workflowをローカル環境で使ってみる
Argo について
Argoはコンテナベースのワークフローエンジンで,ワークフローの各ステップをコンテナとして実装することを可能にします.
つまり,各ステップはdockerのイメージを使用して実行されます.
また,ワークフローはCustom Resource Definitionで定義します.
そのため,kubenetesのマニフェスト管理と同様にhelmなどを用いてワークフロー自体を管理することができます.
GoogleやGithub,PFNなど世界でも有数の企業がArgoを使用しています.
Argo公式ブログがより詳しいです.
Introducing Argo — A Container-Native Workflow Engine for Kubernetes
Argoのインストール
Argoを導入するための下準備
ここではminikubeを使って,kubenetesのクラスタをローカルに構築します.
minikubeを使うことでGoogle Kubenetes Engineなどを使うことなく無料でkubenetesクラスタを構築できます.
手元で試したいときに非常に便利ですね.
minikubeのインストールはHomebrewで一発完了です.
brew cask install minikube minikube version
minikubeを起動してみましょう.
minikube start
初めての起動が正常にいけば以下のような表示がえられるでしょう.
次にkubectlでコンテキストを確認してkubectlがminikubeに向いていることを確認できたらminikubeの導入は完了です.
kubectl config current-context
Argoのインストール
これでArgoをインストールする準備は整いました.
まず,はじめにArgo CLIをインストールします.
brew install argoproj/tap/argo argo version
次に,Namespaceを作り,argoのmanifestを反映します.
kubectl create ns argo kubectl apply -n argo -f https://raw.githubusercontent.com/argoproj/argo/v2.2.1/manifests/install.yaml
次に default
サービスアカウントに権限を付与します.
これをしないと,Argo Workflowの一部の機能(ファイル出力,secretへのアクセスなど)が使えません.
kubectl create rolebinding default-admin --clusterrole=admin --serviceaccount=default:default
ここまでくれば導入完了も同然です. port forwardをしてargo UIにアクセスしてみましょう.
kubectl -n argo port-forward deployment/argo-ui 8001:8001
http://localhost:8001/workflows
では,実際にworkflowを実行してみましょう.
argo submit --watch https://raw.githubusercontent.com/argoproj/argo/master/examples/hello-world.yaml
apiVersion: argoproj.io/v1alpha1 kind: Workflow metadata: generateName: hello-world- spec: entrypoint: whalesay templates: - name: whalesay container: image: docker/whalesay:latest command: [cowsay] args: ["hello world"]
このworkflowでは dockerのwhalesayイメージを指定し,cowsayコマンドを実行します.
watchオプションをつけることで実行状況をterminal上で監視できます.
workflowは一つのpodとして実行されるので argo list
または kubectl get po
で一覧を取得することができます.
実行したworkflowはもちろんargo UIからも確認できます.
workflowをクリックすることで詳細を確認できます.
YAML
からsubmitしたworkflow,logからworkflowからの標準出力を確認することができます.
次は別のworkflowの例です.
argo submit --watch https://raw.githubusercontent.com/argoproj/argo/master/examples/coinflip-recursive.yaml
apiVersion: argoproj.io/v1alpha1 kind: Workflow metadata: generateName: coinflip-recursive- spec: entrypoint: coinflip templates: - name: coinflip steps: - - name: flip-coin template: flip-coin - - name: heads template: heads when: "{{steps.flip-coin.outputs.result}} == heads" - name: tails template: coinflip when: "{{steps.flip-coin.outputs.result}} == tails" - name: flip-coin script: image: python:alpine3.6 command: [python] source: | import random result = "heads" if random.randint(0,1) == 0 else "tails" print(result) - name: heads container: image: alpine:3.6 command: [sh, -c] args: ["echo \"it was heads\""]
このworkflowでは,前回ステップの結果をもとに次のステップを変更しています.
また,templateを定義し,再利用することができます.
初めのステップでpythonのdockerイメージからrandintを実行します.
その結果を次のステップの when
で受け取り,条件によってはもう一度同じテンプレートからステップを実行します.
これによって再帰的にステップを実行することも可能です.
今回はローカルにインストールしたminukubeにArgoをインストールし,2つの例を実行しました.
Argoでできることはここで紹介したことにとどまりません.
次回以降,豊富なworkflow例と共にcookbookを作成・紹介していきたいと思っているので引き続きよろしくお願いします.
参考
書きました
新人ブログリレーということ以下の記事を書きました.
あなたの生産性を向上させるJupyter notebook Tips | リクルートテクノロジーズ メンバーズブログ
他の同期は入社してからの業務を中心に記事を書いています.
私の記事は自分が持っているjupyterについての知識を体系化しました.
お役に立てば幸いです!
YANS2018参加記録
久しぶりに投稿します.
8/27(月)-8/29(水)に NLP若手の会(通称: YANS) 第13回シンポジウム(2018) に参加して来ました.
いくつか気になったポスター発表のまとめを以下に掲載します.
Embeddingの圧縮アツい!
GW開発合宿でGoのレコメンドエンジンを作った
開発合宿
GWを利用して同期の有志5人で開発合宿をしました.
場所は千葉県の東浪見駅から徒歩20分ほどの民泊で海まで徒歩5分という感じです.
東浪見駅はめちゃめちゃ自然あふれる場所で最寄りのコンビニも数キロ先という都会の喧騒とはかけ離れた場所でした.
開発は個人で勝手にテーマを決めてやっていくスタイルです.
GAEやElixir, CTFなど各々がやりたいことをやりました.
僕は今更ながらGo入門してました.
やったこと
作ったのがこれです.
ユーザ/アイテムベースの協調フィルタリングを行うことができる簡素なレコメンドシステムです.
具体的には,ユーザのアイテム評価データからアイテム評価行列を作ってユーザ/アイテムの類似度を計算し, 類似のユーザ/アイテムを取得可能というものです.
類似度関数は以下をサポートしています.
普段は機械学習周りをやっててPythonなどリッチな言語を使ってライブラリで済んでしまうことも多いのでこういうのもいいです.
疲れ時は海に散歩に行ったり,アコギ持ってきてる人がいてみんなで歌ったり,終始酒を飲みながら開発をして最高でした.
Notionが便利すぎてぜひオススメしたい話
こんにちは.最近は,入社して社会になってます.
Notion*1
自分の考えをまとめたり,チームでの共有事項をまとめたりするためのツールはたくさんあります.
Google DocsやGithub Wiki,TrelloやDropbox Paper,どれも素晴らしいツールですが,色々導入しすぎて「今回の案件にはどれが最適なのか...」ってなる時ありませんか?
それを解決してくれるのがNotionです.
Notionとは All-in-one workspace
を謳った上記ツールの機能を網羅してくれるツールです.
概要
Notionは,web,Desktop,iOSなど各プラットフォームにも対応しており,電車での移動中などでも簡単に編集・確認ができます.
コードを含めたドキュメントはもちろん,kanbanや表,カレンダーを作成することができます.
これらの要素は,インラインでページ中に作成することもでき,また,一つの独立したページとして作ることも可能です.
非常に理に適った動きをしてくれて,ストレスフリーな設計となっています.
数式もTeX形式で書くことができて簡単です.
kanban機能もあります.
表やカレンダーも作成することができます.
作成したドキュメントやkanban,表・カレンダーは簡単に構造化して管理することができます.
もちろん,権限があれば複数人での編集も可能です.
論文のまとめと実装の管理の例*2
↓↓↓論文のまとめと実装の管理の例を公開してみたのでぜひ触ってみてください↓↓↓
https://www.notion.so/lapiszero09/Home-b322ed09f90e405e9a9d904709fb37fb
↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑
上のリンクからnotionのデモを触れるので,「論文まとめ」や「実装」などをクリックしてみてください.
An Adaptive Version of the Boost by Majority Algorithm
An Adaptive Version of the Boost by Majority Algorithm
Freund, Y.: An adaptive version of the boost by majority algorithm, Machine Learning, vol. 43, no. 3, pp. 293‒318, 2001.
実装はこちらです.
以前ゼミで,Introduction to ensemble methods for beginners Introduction to ensemble methods for beginners
というタイトルで以下のようなアンサンブルについての概要を簡単にまとめた資料を発表しました(突貫工事なのでまとまりがないのは目を瞑っていただきたく).
その中でも,History of Boosting節のBrownBoost*1について今回は紹介します.
BrownBoost
本論文では,AdaBoostがノイズに弱い*2という特徴を考慮したBrownBoostを提案しています.
本論文以外でも,ノイズ耐性を持つBoostingは提案されています*3.
しかし,PACの枠組みの中で定義されている正式なブースティング特性はありません.
そこで,FreundはBoosting-By-Majorityを適用することでノイズ耐性を受け継いだPACに基づいたBrownBoostを提案しました.
Boosting-By-Majority*4
Boosting-By-Majority(以下,BBM)は,Freundによって提案された最初のブースティングアルゴリズムです.
また,ノイズ耐性*5があります.
しかし,BBMを使用するためには,基本学習器の誤差限界パラメータとして与える必要があり,実用には向きません.
BBMとBrownBoostの関係
このパラメータを削除した手法がBrownBoostです.
具体的には,パラメータを削除し,候補モデルの誤差に適応的に最終的なアンサンブルを生成します.
BrownBoostでは,候補モデルの数を指定するパラメータはなく,代わりにというパラメータがあります.
このパラメータによって”残り時間”が決められます.
この”残り時間”は,BrownBoostのイテレーションに関するパラメータです.
また,候補モデルが訓練された時,”残り時間”と分類マージンに基づいてそのモデルの重みが決まります.
そして,その重みに基づいて”残り時間”は減算されます.
この分類マージンにより,BBMやBrownBoostでは各インスタンスを正解するように候補モデルを訓練することができます.
一方で,各候補モデルの正答率が低いインスタンスに対しては,負のマージンが割り当てられ,そのようなインスタンスはノイズであると判断されます.
これがBBM,BrownBoostがノイズに強い所以です.
BrownBoostのイテレーションは,”残り時間”が収束するまで続きます.
候補モデルの重み付き誤差によって決まるのでアルゴリズムの実行時間はBBMに比べて非常に長くなり得ます.
この問題に関しては,候補モデルの誤差を0.5に近づけた時のアンサンブルの振る舞いを考えることで答えることが可能です.
モデルの誤差をとし,時間におけるアンサンブルの出力をと表します.
この時,時間ををすると,BrownBoostの収束後の時間はとは無関係の定数となり,これにより下式で表現される”位置”を定義します.
モデルの誤差を0.5に近づけることが前提であり,この時の”位置”は下式で得られる平均と分散によって特徴付けられるブラウン運動による連続時間確率過程に近づきます.
ここで,は時刻におけるモデルの重み付き誤差を表します.
ここまでで誤差を0.5に近づけた時のアンサンブルの振る舞いがわかりました.
ここからはBBMで定義される重みの限界を考えます.
BBMの重み関数は二項分布であり,以下で表します.
ここで,はイテレーションの回数で,は現在のイテレーション番号です.
はこれまでの正しい予測の数です.
前述のとの定義から誤差を0.5に近づけるとの限界を得ることができます.
ここで,です.
同様に,誤差限界は下式のように表せます.
ここで,は誤差関数であり,下式で表現されます.
以上のように,,,及び,の定義が与えられると、BBMアルゴリズムを連続時間領域に変換できるようになります.
この領域では,BBMを実行する代わりに,候補モデルの誤差が0.5である分布に対応するの値を定義する微分方程式を解くことで近似解を得ることができます.
また,BrownBoostの損失関数は,に対応した形になり,同様に近似によって多項式時間での収束が可能となります.
ブラウン運動のシュミレーション
ブラウン運動は,連続時間ランダムウォークをモデル化したものです.
2次元ブラウン運動のシュミレーションを行い,その結果を可視化しました.
実装はここにあるので,遊んでみてください.
BrownBoost/brownian_motion.ipynb at master · lapis-zero09/BrownBoost · GitHub
BrownBoostの実装
BrownBoostのアルゴリズムは以下に示すとおりです.
アルゴリズム中ステップ3は論文に示されているとおり,Newton-Raphson法により解を導いています.
実装はこちらです.
実験
ここに簡単な実験を示しています.
BrownBoost/brownboost.ipynb at master · lapis-zero09/BrownBoost · GitHub
感想など
尻すぼみ感が否めませんが,明日も朝から研修なので許して.
後ほど追加実験をします.
*1:Freund, Y.: An adaptive version of the boost by majority algorithm, Machine Learning, vol. 43, no. 3, pp. 293‒318, 2001.
*2:Dietterich, T. G.: An experimental comparison of three methods for constructing ensembles of decision trees: Bagging, boosting, and randomization. Machine Learning, 40:2, 139–158.
*3:Friedman, J., Hastie, T., Tibshirani, R.: Additive logistic regression: a statistical view of boosting. Ann. Statist. 28 (2000), no. 2, pp. 337-407.
*4:Freund, Y.: Boosting aweaklearning algorithm by majority. Information and Computation, 121:2, 256–285.
*5:J. a. Aslam and S. E. Decatur, General bounds on statistical query learning and PAC learning with noise via hypothesis boosting, Proc. 1993 IEEE 34th Annu. Found. Comput. Sci., vol. 118, pp. 85‒118, 1993.
Model Compression
Model Compression
Bucilua, C., Caruana, R. and Niculescu-Mizil, A.: Model Compression, Proc. ACM SIGKDD, pp. 535–541 (2006).
https://dl.acm.org/citation.cfm?id=1150464
モデル圧縮
この論文はHintonらによるDistilling the Knowledge in a Neural Network*1の先駆けとなった論文です.
Once the cumbersome model has been trained, we can then use a different kind of training, which we call “distillation” to transfer the knowledge from the cumbersome model to a small model that is more suitable for deployment. A version of this strategy has already been pioneered by Rich Caruana and his collaborators.
Hintonらの"distillation"では,soft targetと呼ばれる教師モデルが各正解クラスに対して出力した確率分布を入力として与えます.
これにより,教師モデルが相対的に間違った答えを導き出した場合でも,生徒モデルは教師モデルの確率という情報を新たに得ることで正解できるように訓練できる場合があります.
この時,教師モデルに含まれる基本モデルの予測分布の算術または幾何平均をsoft targetとしています.
一方,Caruanaらの提案では,教師モデルによって予測したone-hotなクラスラベルを正解ラベルとして生徒モデルを学習します.
これだけでは,生徒モデルは元の訓練データで訓練できるモデルよりもいいパフォーマンスを発揮できません.
教師モデルの性能を維持したまま元の訓練データで訓練できる軽量なモデルよりも優れたモデルを得るために擬似データを使用しました.
Model Compression*2
Caruanaらは,同論文でMUNGEと呼ばれる擬似データ生成手法を提案し,訓練データから大量のラベルなし擬似データを生成し訓練に用いました.
ラベルなし擬似データは,教師モデルによってラベル付けし,生徒モデルの訓練に利用します.
MUNGEのアルゴリズムは以下の通りです.
MUNGEは見ての通り非常にシンプルな擬似データ生成手法です.
まず,訓練データ内のインスタンスについて,最近傍を発見します.
次に,確率パラメータに基づいて最近傍とインスタンスから得られる正規分布からランダムに値を得ます.
これを各インスタンスに対し,複数回行うことで任意数の擬似データを得ることができます.
MUNGEの実装はこちらです.
Caruanaらは,MUNGEを用いたモデル圧縮によって教師モデルと同等の精度を持ち,教師モデルよりも高速かつ軽量な生徒モデルを得ています.
簡単な実験
irisデータを使って擬似データ生成アルゴリズムMUNGEの簡単な実験を行います.
詳細はこちらです.
正規化したirisデータは以下のようにプロットできます.
このデータに対してMUNGEで擬似データを生成した結果が以下です.
良さそうですね.
しかし,確率パラメータを低くすると擬似データ数が多くなった場合,過学習する可能性が高くなります.
導出アルゴリズムからわかるように分散パラメータが小さすぎると明らかにおかしいデータが出来上がります.
一見良さそうですが,データが密集しすぎていて経験的にこのような場合は過学習することがあります.
いい感じです.
少し規則的に並び過ぎて人工データ感が否めません.
感想など
MUNGEはとてもシンプルなアルゴリズムであるが,データの特徴をよく捉えた擬似データを生成できます.
しかし,クラスの決定境界を意識して擬似データを生成しているわけではありません.
よって,上述のirisデータのプロットにおける青と緑のように境界面で接地したデータを分離しづらくなってしまいます.
あくまでも,モデル圧縮の文脈における擬似データ生成アルゴリズムなので教師モデルの特徴を活かしてこの辺りを改善したいですね.
*1:Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the knowledge in a neural network. In In Deep Learning and Representation Learning Workshop, NIPS, 2014.
*2:Bucilua, C., Caruana, R. and Niculescu-Mizil, A.: Model Compression, Proc. ACM SIGKDD, pp. 535–541 (2006).
Selective Ensemble under Regularization Framework
Selective Ensemble under Regularization Framework
Li, N. and Zhou, Z.-H.: Selective Ensemble under Regularization Framework, Proc. MCS, pp. 293–303 (2009).
実装はここに置いています.
アンサンブル枝刈り
アンサンブル枝刈りはensemble pruningとも呼ばれます.
アンサンブル枝刈りは訓練された複数の基本モデルが与えられた上で,その全てを結合するのではなく,部分集合を選択することです.
アンサンブル枝刈りにより,より小さなサイズのアンサンブルでより良い汎化性能を得ることが期待されます.
Tsoumakasらは,アンサンブル枝刈りを順序付けに基づく枝刈り,クラスタリングに基づく枝刈り,最適化に基づく枝刈りの3つのカテゴリーに分類しました*1.
この詳しい話については後ほど別の記事にしようと思ってます.
今回はその中でも最適化に基づく枝刈りに属するRSE(regularized selective ensemble)についてです.
RSE(regularized selective ensemble)*2
RSEはLiとZhouによって提案されたアンサンブル枝刈りをQP問題へと帰着させる手法です.
この手法の良いところは,従来のアンサンブル枝刈り手法よりもアンサンブルに含まれる基本モデルの数が少ないにもかかわらず,その性能がいいという点です.
また,正則化項にグラフラプラシアン正則化を用いており,半教師あり学習の場面でも使用することができます.
それでは詳しく見ていきましょう.
個のモデルに対し,アンサンブル結合重みベクトルをと定義します.
この時, かつ です.
RSEは,以下の正則化リスク関数を最小化することによりを決定します.
ここで,は訓練データに対する誤分類の経験損失で, は正則化項であり, はとの最小化における正則化パラメータを表す.
ヒンジ損失とグラフラプラシアン正則化項をそれぞれ経験損失と正則化として用いることにより,問題は下式で定式化されます.
ここで,は訓練データに対する個々のモデルの予測を表し, は全訓練データに対する全モデルの予測を集めた予測行列で, です.
は訓練データの近傍グラフの正規化グラフラプラシアンです.
グラフラプラシアンの詳細については,長くなってしまうので件の論文や他文献に譲ります.
行列の重み付け隣接行列をで表し,はの対角行列です.
この時,となります.
上式のは滑らかではので,スラック変数を導入することにより,上式は下式で書き表せます.
この時,上式は標準的なQP問題となり,従来の最適化パッケージを用いて,効率的に解くことができます.
また,という制約は、スパース誘導性を持つノルム制約となり,重みのいくつかの要素を強制的にゼロにします.
この特徴により,従来のアンサンブル枝刈り手法よりもアンサンブルに含まれる基本モデルの数が少ないにもかかわらず,それなりの性能を出すことができます.
また,簡単な例を代入すればわかりやすいのですが,正則化項の部分が似ているデータに対して違うラベルを割り当てると正則化するという振る舞いをします.
導出された結合重みベクトルを用いて,下式のようにアンサンブルの予測を得ます.
また,下式のようにの要素がゼロでない候補モデルの投票により予測を決定する選択的アンサンブルも提案されています.
少なくともerrorに関してはなかなかいい結果です.
errorだけの比較なのは,あまりいい感じはしませんが,アンサンブルの目指す所的にいいのかな
実装
実装はここに置いています.
最適化を行う前に,グラフラプラシアンに用いる行列を導出する必要があります.
これは,pythonでは遅すぎたのでJavaで実装しています.
しかし,訓練データはnpz形式で保存していたので,pythonでデータを読み込んでpy4j*3で起動したJavaサーバで計算を行っています.
/** * get the matrix W for Laplacian * @return w_link filename * @throws Exception */ public String[] getKernelLinkMatrix(byte[] data, String fold) throws Exception { System.out.println("\t\t[*] Hello from Java!"); // get the link matrix // create java matrix from numpy Instances d = createFromPy4j(data); Matrix LMat = new Matrix(m_numInst, m_numInst); RBFKernel krl = new RBFKernel(); double g = 0.5 / m_numAttr; krl.setGamma(g); krl.buildKernel(d); System.out.println("\t\t[*] calculating w_link on Java..."); for (int i = 0; i < m_numInst; i++) { for (int j = 0; j <= i; j++) { double vt = krl.eval(i, j, d.instance(i)); LMat.set(i, j, vt); LMat.set(j, i, vt); } } d = null; krl = null; System.out.println("\t\t[-] calc done..."); return returnFilename(LMat, m_numInst, fold); }
ラプラシアン行列の計算や重みの最適化はMatlabで行いました.
本件ではpythonよりも速く,また,最適化パッケージであるmosek*4の扱いが簡単だったためです.
mosekは論文でも採用されており,QP問題に関しては最速らしいです.
function Lap = LaplacianMatrix(W) % compute the laplacian matrix % % Input: % W: the line matrix % Output: % Lap: the laplacian matrix D = diag(sum(W,2)); N = D^(-0.5); Lap = N *(D-W)* N;
function weight = compute_weight(lambda, Y, Prd, Q, dirname); % compute the weights for the base classifiers % % Input: % Cfr: the ensemble classifier, which contains multiple base % classifiers % Output: % weight: the weights for the base classifiers in the ensemble % M : set the size of classifiers % N : set the size of training data [M, N] = size(Prd) NumVar = M + N; % ---- Prepare QP ---- %equal constaint Aeq = [ones(1,M),zeros(1,NumVar - M)]; beq = 1; % lower and upper bounds lb = zeros(NumVar,1); % Ax <= b AP = Prd'; for id = 1:N AP(id, :) = AP(id, :) * Y(id); end AP = -1 .* AP; A = [ AP, -1 * eye(N) ]; b = -1 * ones(N ,1); % H H = zeros(NumVar,NumVar); H(1:M, 1:M) = Q; % f % lambda = 1.0; f = lambda * [zeros(M,1);ones(N, 1)]; % Optimization Using QP - MOSEK options = optimset('Display','off','TolFun', 1e-100); x0 = quadprog(H,f,A,b,Aeq,beq,lb,[],[],options); weight = x0(1:M); weight((weight <= 1e-6)) = 0; filename = sprintf('%sweight/weight_lambda_%d%s',dirname,lambda,'.csv') disp(filename) csvwrite(filename,weight);
簡単な実験
Phonemeデータセットに対して,5分割交差検証でRSEを使用しました.
SVM,KNN,Decision Tree,Bagged Decision Trees,Boosted Decision Trees,Boosted Decision Stumpsなどを様々なパラメータで767個のモデルをアンサンブルの候補としました.
このとき最もよかった候補モデルは,XGBoostではなくRandomForestで,mseが0.092937219730941698,f1が0.93484136670163342という値でした.
RSEは,mseが0.0859865470852018,f1が0.93981530924414192という結果だったので良くなっていますね.
best base model | RSE | |
---|---|---|
mse | 0.0929372197309416 | 0.93484136670163342 |
f1 | 0.0859865470852018 | 0.93981530924414192 |
詳細は,
感想など
グラフラプラシアンを使っている以上うまく適合するデータとそうでないデータはもちろんあると思います.
また,この手法では,アンサンブル候補モデルに対する結合重みを二次計画問題と捉え最適解を導出しますが,候補モデルが一律に似たような性能を示している場合,二次計画問題となり得ず,解くことができません.
ノルム制約を使って候補モデルの数をガンガン減らすのはとても理にかなっていると思ったので,回帰とかにも使えると嬉しいですね.