オオハタの研究ノート

考えたこととか勉強したこととか、書いていきます。

SketchRNN を作った

SketchRNN を作った

この記事は群馬大学電子計算機研究会 IGGG Advent Calendar 2019 - Adventar 15 日目の記事です。一週間遅れの.

B3 であるということもあり,研究室ではあまり成果を要求されることはないので ゆるりとやっており,その一環として SketchRNN を作りました.

結果はこんな感じです.

f:id:KenjiOhata:20191222192536p:plain

モデル概要

SketchRNN は Google の Magenta プロジェクトの一環として作られたスケッチ生成のための モデルです.

中身の基本はは seq2seq の VAE になっています.筆の動きを表す,ストロークデータを 入力することで,ストロークを再現するように学習します.

左が入力で右が出力になります.

f:id:KenjiOhata:20191222192539p:plain

SketchRNN では ニューラルネットワークの出力はストロークデータではなく, ストロークデータを生成するような確率分布のパラメータを出力します. これにより,ランダム性を残しながらスケッチ生成ができます.

詳細は論文を参照ください. https://arxiv.org/abs/1704.03477

Julia -> Pytorch

SketchRNN の作成に取り掛かったのは 11 月中旬くらいでした. 最初の段階では Julia と Flux.jl で実装することを考えていました.

Julia 採用理由として

Python のようにライブラリを通して行列演算をすることの煩わしさがないことや, exp などから標準であること,Unicode の利用で コードの見た目がかなり簡潔になることか挙げられます.

また,Julia 自体にも興味があったので速度の面や実用さを知りたかったというのもあります.

結果として,Pytorch で書き直すことになりました.

自分が書いたコードが GPU で動かなかったからです. 公式サンプルは動きますが,僕の書いたコードはエラーが残ってしまいました. 追求すれば解決できるかもしれませんでしたが,GPU で動かすことに時間を費やすことは SketchRNN 実装とは違うので本筋を見失わないうちに Pytorch に移行しました.

ただ,Julia はマクロ機能や Unicode などほんとにいい言語だと思いました. Flux.jl もまだバージョン 1 にもなっていない(2019/12 現在)のでまだ,成長を待つ段階 なのだと思いました.

オンプレ機械学習いいな

初めてオンプレでまともに機械学習をしました. 今までは Google Colab を使用して来ましたが,時間制約のおかげで, 12 時間毎に操作が要求されたり,その間のコンピュータの使用が制限されるなど ろくなものではなかったので,オンプレ最高でした.研究室に感謝です.

今回作成した SketchRNN-Pytorch のレポジトリはこちらにおいておきます.

github.com