ほぼPython

Not技術ブログBut勉強ブログ 内容には誤りがあることが多いです

Pythonで2変数関数の勾配を可視化する

前回までではじめてのPython企画(Pythonに慣れよう!)は終了して、ここからついに、目標だった機械学習の勉強を進めていきます。


ただ、僕はどちらかというと学問的なところに興味があるので、機械学習のライブラリを使って面白いものを作る!とかではなく、学問的な内容が多くなると思います。

今回は機械学習のための基礎として二変数関数の勾配をグラフとして表すプログラムを書こうと思います。

勾配とはなにか

勾配とは多変数関数の偏微分を並べたベクトルのことで、ある点における傾き具合(傾きの方向と傾きの強さ)を表しているようなイメージだと思います。

例えば、階段で、上にあがっていくのはキツイですが、横に動くのは簡単ですよね。
このように多次元の空間では一地点でも複数の傾きを持っているわけですね。それをまとめたもののことを勾配と呼んでいるのだと思います。(たぶん)

勾配がなぜ機械学習に使えるのか

後の勾配のグラフを見るとわかると思いますが、勾配をたどっていくと、いまの地点よりも高い位置に行くことができ、勾配を逆からたどっていくと、今の地点よりも低い位置に行くことができます。

ところで、機械学習の問題は最終的には、関数の最大値または最小値を求める問題に帰着するみたいです。

ということは、あるデータに対して関数化し、計算するわけですが、データをいい感じに関数化する際に、誤差という問題が出てきます。その誤差を最小にするために誤差からまた関数を作り、その値が最小となる点を求めるのに上記の勾配の性質が利用されているようです。(たぶん)

偏微分係数を計算するプログラム

実際に勾配を可視化するプログラムを書いていきたいと思いますが、まずはその前に偏微分係数を計算するプログラムを書きます。
今回使う関数はこちら!
{\displaystyle }f(w_0,w_1)=w_0^2 + 2w_0w_1 + 3

まずはモジュールのインポートと、関数と偏微分した関数の宣言をします。

import numpy as np
import matplotlib.pyplot as plt

def f(w0,w1):
    return w0**2 + 2*w0*w1 + 3 #今回使う関数

def df_dw0(w0,w1):
    return 2*w0 + 2*w1 #w0で偏微分

def df_dw1(w0,w1):
    return 2*w0 #w1で偏微分

次に数学でいう定義域の決定的なことをします。

w_range = 2 #範囲
dw = 0.25 #範囲を区切る間隔
w0 =np.arange(-w_range,w_range,dw) #-2から2未満まで0.25の間隔で配列(ndarray)を作る
w1 =np.arange(-w_range,w_range,dw)
wn = w0.shape[0] #そのベクトルのサイズ

実際に偏微分していきます。

dff_dw0 = np.zeros((len(w0),len(w1))) #各点のw0での偏微分係数を入れるための配列
dff_dw1 = np.zeros((len(w0),len(w1))) #w1での(以下略)

for i0 in range(wn):
    for i1 in range(wn):
        dff_dw0[i1,i0]=df_dw0(w0[i0],w1[i1]) #各点におけるw0での偏微分係数
        dff_dw1[i1,i0]=df_dw1(w0[i0],w1[i1]) #w1での(以下略)

試しにdff_dw0[1][2]の値をprintしてみると-6.5と表示されました。
これはfのw0=-1.75,w1=-1.5の点におけるw0による偏微分係数と一致していますね。(手計算しました)

可視化するプログラム

基本的にはこれだけで大丈夫です。

plt.quiver(w0,w1,dff_dw0,dff_dw1)
plt.show()

しかし、軸とかがいまいちなので体裁を整えます。
plt.show()の前に以下を追加します。

plt.xlabel('$w_0$',fontsize=14) #x軸のラベル
plt.ylabel('$w_1$',fontsize=14) #y軸のラベル
plt.xticks(range(-w_range,w_range+1,1)) #x軸に表示する値
plt.yticks(range(-w_range,w_range+1,1)) #y軸に表示する値
plt.xlim(-w_range-0.5,w_range+0.5) #x軸の範囲
plt.ylim(-w_range-0.5,w_range+0.5) #y軸の範囲

f:id:short_4010:20180226125630p:plain
こんなグラフが得られます。矢印はグラフの斜面の高い方へ向いており、矢印が長いほど、その斜面が急であることを示しています。

ちなみにグラフを描くとこんな感じになります。勾配のグラフと見比べると面白いですね。
f:id:short_4010:20180226133502p:plain


追記 なんか機種によってはプログラムのインデントが変に表示されるかもしれませんので注意してください