matplotlib.animationを使ってみた

1年ちょっと前に、PythonとMatplotlibを使ってこんな感じのアニメーションを作りたくなったが、うまく動かせないまま、忙しくなって放置していた。

今年こそまたPythonのプログラミングを勉強を再開するのを新年の抱負として、この正月休みに再挑戦した。

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

class Ball(object):
    def __init__(self, angle, speed):
        self.angle = angle
        self.speed = speed

    def __str__(self):
        return "%d, %d" % (self.angle, self.speed)

    def init_state(self):
        self.x = 0
        self.y = 0
        self.vx = self.speed * np.cos(self.angle * np.pi / 180)
        self.vy = self.speed * np.sin(self.angle * np.pi / 180)

    def step(self, dt):
        if self.y >= 0:
            self.x += self.vx * dt
            self.y += self.vy * dt
            self.vy -= 9.8 * dt

class Target(object):
    def __init__(self, x, y, score):
        self.x = x
        self.y = y
        self.score = score
    def __str__(self):
        return "(x,y)=(%d,%d) score=%d" % (self.angle, self.speed, self.score)
    def copy(self):
        return Target(self.x, self.y, self.score)

class ShootingDemo(object):
    dt = 0.05
    R = 1.0

    def __init__(self, balls_in, targets_in):
        self.balls = balls_in
        self.targets = targets_in

    def init_state(self):
        for b in self.balls:
            b.init_state()
        for t in self.targets:
            t.broken = False
        self.score = 0
        self.time = 0

    def step(self):
        for b in self.balls:
            if b.y >= 0:
                b.step(self.dt)
                for t in filter(lambda t: not t.broken, self.targets):
                    if (t.x - b.x) ** 2 + (t.y - b.y) ** 2 < self.R ** 2:
                        self.score += t.score
                        t.broken = True
        self.time += self.dt

    def finished(self):
        return np.all([b.y < 0 for b in self.balls])
        
class ArtistAnimationDemo(ShootingDemo):
    def __init__(self, balls_in, targets_in):
        super().__init__(balls_in, targets_in)

    def start(self):
        fig = plt.figure(figsize=(8, 4))
        axes = fig.gca()

        self.init_state()
        ims = []
        while not self.finished():
            self.step()
            im = []

            # Ball image
            for b in self.balls:
                if b.y >= 0:
                    circle = plt.Circle([b.x, b.y], radius=0.1, color='k')
                    im.append(axes.add_patch(circle))

            # Target image
            for t in self.targets:
                if not t.broken:
                    circle = plt.Circle([t.x, t.y], radius=self.R, color='r', fill=False)
                    im.append(axes.add_patch(circle))
                    im.append(axes.text(t.x, t.y, t.score,
                                        horizontalalignment='center',
                                        verticalalignment='center',
                                        fontsize=8))

            # Canon image
            im.append(axes.add_patch(plt.Rectangle([0, -0.7], 2, 1, 45)))

            # Set title
            # Unfortunately ArtistAnimation doesn't set title for each frame
            # im.append(axes.set_title('time={:.2f} score={}'.format(self.time, self.score)))
            # Instead
            im.append(axes.text(0.5, 1.01, 'time={:.2f} score={}'.format(self.time, self.score),
                                ha='center', va='bottom',
                                transform=axes.transAxes, fontsize='large'))
            ims.append(im)

        ani = animation.ArtistAnimation(fig, ims, interval=1000 * self.dt)
        plt.xlim(0, 50)
        plt.ylim(0, 25)
        plt.show()
        # ani.save("output.gif", writer="imagemagick")

class FuncAnimationDemo(ShootingDemo):
    def __init__(self, balls_in, targets_in):
        super().__init__(balls_in, targets_in)

    def start(self):
        def frame_gen():
            while not self.finished():
                self.step()
                yield self.time

        def plot(time):
            plt.cla()

            # Draw balls
            for b in self.balls:
                if b.y >= 0:
                    plt.gca().add_patch(plt.Circle([b.x, b.y], radius=0.1, color='k'))

            # Draw targets
            for t in self.targets:
                if not t.broken:
                    axes.add_patch(plt.Circle([t.x, t.y], radius=self.R, color='r', fill=False))
                    axes.text(t.x, t.y, t.score,
                              horizontalalignment='center',
                              verticalalignment='center',
                              fontsize=8)

            # Draw canon
            axes.add_patch(plt.Rectangle([0, -0.7], 2, 1, 45))

            # Set title
            plt.title('time={:.2f} score={}'.format(time, self.score))

            # Set range
            plt.xlim(0, 50)
            plt.ylim(0, 25)

        # body of start()
        fig = plt.figure(figsize=(8, 4))
        axes = fig.gca()
        ani = animation.FuncAnimation(fig, plot, interval=1000*self.dt, frames=frame_gen, init_func=self.init_state, repeat=True)
        # ani.save("output.gif", writer="imagemagick")
        plt.show()

if __name__ == '__main__':
    balls_sample = [Ball(angle, speed) for angle, speed in ((30, 25), (45, 20), (60, 15))]
    targets_sample = [Target(3*x + 10, 3*y + 3, 5*x + 10) for x in range(10) for y in range(3)]
    demo = ArtistAnimationDemo(balls_sample, targets_sample)
    demo.start()
    demo = FuncAnimationDemo(balls_sample, targets_sample)
    demo.start()

matplotlib.animationにはArtistAnimationとFuncAnimationという2種類のアニメーションの為のクラスがあるが、両方を試した。ArtistAnimationDemoがArtistAnimationを使ったもの、FuncAnimationDemoがFuncAnimationを使ったもので、同じものを同じスビードで描画しているつもりである。

以下、ArtistAnimationとFuncAnimationの両方を使ってみた上でのメモである。

  • FuncAnimationは遅い
  • Intel Core i5 1.6GHzのMacBook Airで試した限りであるが、この3秒分のアニメーションに、ArtistAnimationを使うと4秒、FuncAnimationを使うと9秒ほどかかる。
    フレーム毎に明示的にplt.cla()で消すのに時間がかかるようでもあるが、上記のコードだとplt.cla()しないとどんどん遅くなる。作り方を工夫すればどんどん遅くならないようにできるのかも知れないが、そもそも毎フレーム消さずに所望のアニメーションを実現できることは少ないだろう。少しずつ描画を進めるようなインクリルンタルなアニメーションでない限りは、速度を求めるならArtistAnimationを使う方が良さそうだ。
    ただ、最終的に動画ファイルを作成するのであれば、どちらを使っても関係ない。
  • ArtistAnimationはバグ解析しにくい
  • ArtistAnimationを使うとプログラムの構造は単純になるが、フレーム毎の処理を埋め込めないので、うまく行かない時の原因解析が難しい。対して、FuncAnimationはプログラムの構造に工夫が必要になるが、print文でフレーム毎の状態を表示できるので、問題解析は比較的容易である。
    当初、次のように、ボールのXY座標をリストに入れて、それを更新しながらそのままMatplotlibのAPIに渡すようなコードにしていたが、これだとボールが10フレーム目(最終フレーム)の位置から動かなかった。
    places = [[0, 0], # x, y of ball 1
              [0, 0], # x, y of ball 2
              [0, 0]] # x, y of ball 3
    
    for _ in range(10): # toward the 10th frame for testing
        for i in range(3):
            places[i][0] += velocity[i][0] # x += velocity(x)
            places[i][1] += velocity[i][1] # y += velocity(y)
            if places[i][1] >= 0:
                circle = plt.Circle(places[i], radius=0.1, color='k')
                im.append(axes.add_patch(circle))
    			...
    
    その原因が、参照渡しでplt.Circle()に渡したplaces[i]が即座に参照されず、リファレンスだけが保持され、ArtistAnimationの開始後に全て参照されるからで、places[i]をplaces[i].copy()またはplaces[i][:]に変えて複製を渡すようにすると正しく動くことに気付くのに、1年以上かかった。
    FuncAnimationを使っていれば、このような後で評価されるから起こる問題は起こらないし、起こってもすぐに原因を特定できただろう。
  • ArtistAnimationはフレーム毎にset_titleできない
  • 仕様なのかバグなのか不明だが、上記のコードでコメントアウトされている部分のように、フレーム毎にset_title()しようとしても、全フレームのtitleが最終フレームのtitleになってしまう。axes.set_titleをplt.titleにしても同様である。同様の報告はWeb上に複数あった。
    ArtistAnimationで何がフレーム毎にできないのかがわからないのは不便である。やってみないとわからないし、やってみてできなくても自分のコードの誤りかも知れないのである。
  • FuncAnimationはplt.show()した後にsave()すると正しく保存されない
  • 上記コードのようにしてGIFアニメを保存する場合しか確認していないし、環境依存のバグかも知れないが、plt.show()した後にsave()すると、フレーム毎の再描画がなされず、最後に描画したフレームが連続したファイルになってしまう。

何故こんなものを作ろうと思ったのかというと、機械学習の実験の為だったが、具体的にこれをどう使おうと思ったのかが思い出せない。
的の位置と点数を入力として、なるべく高得点を取れるような3球の初速度と角度を学習させようということだったと思うのだが、回帰問題なのでディープラーニングは使いにくいし、問題が複雑すぎて適当な回帰モデルの想像がつかない。何となく、的の位置と点数が与えられてから、シミュレーションで局所解を探索するしかないような気がする。
これを作ろうと思った当時は、何らか単純なモデルを想定した問題を考えていたと思うのだが...
的の位置を決まった数の格子点に限定して、入力はそれらの的の点数のみとし、球の初速度と角度のパターンを少なするなどして簡単な識別問題にしようとしたのだったか...