网络技术爱好者的栖息之地,让我们的技术更上一层楼!

keras自定义网络层

QQX7网 热点趣闻

在深度学习领域,Keras是一个高度封装的库并被广泛应用,可以通过调用其内置网络模块(各种网络层)实现针对性的模型结构;当所需要的网络层功能不被包含时,则需要通过自定义网络层或模型实现。

如何在keras框架下自定义层,基本“套路”如下。

一般地,keras中的网络层是一个类,所以自定义层即编写一个类,更为重要的是这个类(即自定义层)需要继承Layer父类,而且需要实现以下四种方法:

  1. __init __ (self, output_dim, **kwargs)

这个方法是用来初始化并自定义自定义层所需的属性,比如output_dim;
此外,该方法需要执行super().__init __(**kwargs),这行代码是执行Layer类中的初始化函数;
当执行上述代码就没有必要去管input_shape,weights,trainable等关键字参数,因为父类(Layer)的初始化函数实现了它们与layer实例的绑定。

  1. build(self, input_shape)

这个方法是用来创建层的权重;
在该方法中,根据之前的继承,通过Layer类的add_weight方法来自定义并添加一个权重矩阵,这个方法需要input_shape参数;
该方法必须设self.built = True,目的是为了保证这个层的权重定义函数build被执行过了;
在built函数中,需要说明这个权重各方面的属性,比如shape、初始化方式以及可训练性等信息。

  1. call(self, x)

这个方法是用来编写层的功能逻辑;
在该方法中,需要关注传入call的第一个参数:输入张量x;x只能是一种形式变量,不能是具体的变量,即它不能被定义;
这个call函数就是该层的计算逻辑,当创建好这个层实例后,该实例可以执行call函数;
可见,这个层的核心应该是一段符号式的输入张量到输出张量的计算过程。

  1. compute_output_shape(self, input_shape)

这个方法是用来保证输出shape是正确的;
这里重写compute_output_shape方法去覆盖父类中的同名方法,来保证输出的shape符合实际;
父类Layer中的compute_output_shape方法直接返回的是input_shape这明显是不对的,所以需要重写该方法。

示例

结合官方文档的例子,给出如下一个自定义层的代码:

使用自定义层,就如同使用keras内置网络层一样,如下图所示:(另外,本例使用kears内置的激活函数层ReLU承接自定义层的输出,从而避免将激活函数的功能加入到自定义层中)

Community Cloud零基础学习(五)Topic(主题)管理,Service Cloud 零基础(三)Knowledge浅谈

标签: 暂无标签

免责声明:

本站提供的资源,都来自网络,版权争议与本站无关,所有内容及软件的文章仅限用于学习和研究目的。不得将上述内容用于商业或者非法用途,否则,一切后果请用户自负,我们不保证内容的长久可用性,通过使用本站内容随之而来的风险与本站无关,您必须在下载后的24个小时之内,从您的电脑/手机中彻底删除上述内容。如果您喜欢该程序,请支持正版软件,购买注册,得到更好的正版服务。侵删请致信E-mail:114077@qq.com

广告位
同类推荐
评论列表

最新评论
推荐内容
热门文章
随机推荐
资源标签
热点趣闻 keras自定义网络层
在深度学习领域,Keras是一个高度封装的库并被广泛应用,可以通过调用其内置网络模块(各种网络层)实现针对性的模型结构;当所需要的网络层功能不被包含时,则需要通...
扫描二维码阅读原文
QQX7网 January, 01
生成社交图 ×