深度学习生成舞蹈影片02之MDN代码练习
阅读难度:★★★★☆
技能要求:机器学习基础、Keras、numpy、matplotlib
字数:960字
阅读时长:5分钟
系列:
本文接上一期,补充一些MDN的代码练习。本教程开发环境是python+jupyter,引用了一个用keras写的mdn包,目标是拟合反正弦函数曲线:
y=7.0sin(0.85x)+0.5x+r
该函数在每个点都有多个解,因此要求ANN模型需要有能力处理它的损失函数。 MDN是预测这些多输出值的好方法。
1
引入相关依赖
import kerasimport mdnimport numpy as npimport matplotlib.pyplot as plt
2
生成模拟数据
#y=7.0sin(0.85x)+0.5x+r#r标准的高斯随机噪声NSAMPLE = 3000
y_data = np.float32(np.random.uniform(-10.5, 10.5, NSAMPLE))
r_data = np.random.normal(size=NSAMPLE)
x_data = np.sin(0.85 * y_data) * 7.0 + y_data * 0.5 + r_data * 1.0
x_data = x_data.reshape((NSAMPLE, 1))plt.figure(figsize=(4, 4))
plt.plot(x_data,y_data,'ro', alpha=0.3,markersize = 1)
plt.show()
3
建模
接下来,我们在Keras中构建MDN模型。 使用了Keras中的Sequential模型,其中MDN层位于一个或多个Dense层之后。 您需要为MDN定义输出维度和混合状态的数量,比如:
MDN(output_dimension,number_mixtures)
对于本教程的问题,我们只需要定义输出维度为1,因为我们预测的y值维度为1。 添加更多的混合状态数量会增加更多参数(模型更复杂,需要更长时间训练),但可能有助于使预测结果更好。 你可以从训练数据中看到曲线评估混合状态的数量有5个,因此设置混合状态的数量N_MIXES = 5是比较好的方式。
对于MDN,我们需定义适合的损失函数,使其可以处理混合状态参数,损失函数必须考虑输出维数和混合状态的数量。
N_HIDDEN = 12
N_MIXES = 6
model = keras.Sequential()
model.add(keras.layers.Dense(N_HIDDEN, batch_input_shape=(None, 1), activation='relu'))
model.add(keras.layers.Dense(N_HIDDEN, activation='relu'))
model.add(mdn.MDN(1, N_MIXES))
model.compile(loss=mdn.get_mixture_loss_func(1,N_MIXES), optimizer=keras.optimizers.Adam())
model.summary()
网络结果如下图所示:
4
训练模型
history = model.fit(x=x_data, y=y_data, batch_size=128, epochs=500, validation_split=0.2)
5
可视化
我们通过图表的方式查看模型是如何训练的。 从下图,我们可以看到,经过一定的训练后,训练效果的提升相当缓慢。对于本教程,1.5左右的损失值产生了相当好的结果。
代码如下:
plt.figure(figsize=(10, 5))plt.ylim([0,9])plt.plot(history.history['loss'])plt.plot(history.history['val_loss'])plt.show()
6
预测
现在我们可以通过在x轴上产生3000个均匀间隔点来预测y轴的数值,测试下训练好的模型。注意,y_test包含分布的参数,而不是图上的实际点。要在图上找到点,我们需要从每个分布中进行采样,采样后的结果为y_samples。
x_test = np.float32(np.arange(-15,15,0.01))
NTEST = x_test.size
print("Testing:", NTEST, "samples.")
x_test = x_test.reshape(NTEST,1) y_test = model.predict(x_test)
y_samples = np.apply_along_axis(mdn.sample_from_output, 1, y_test, 1, N_MIXES,temp=1.0)
对比下预测结果:
plt.figure(figsize=(4, 4))
plt.plot(x_data,y_data,'ro',x_test, y_samples[:,:,0], 'bo',alpha=0.3,markersize = 1)
plt.show()
附上keras实现的MDN:
https://github.com/cpmpercussion/keras-mdn-layer