分享web开发知识

注册/登录|最近发布|今日推荐

主页 IT知识网页技术软件开发前端开发代码编程运营维护技术分享教程案例
当前位置:首页 > 运营维护

MXNet中bucket机制注记

发布时间:2023-09-06 01:31责任编辑:董明明关键词:暂无标签

Preface

之前看API以为bucket是一个根植于底层操作的接口(MXNet doc功不可没 -_-|| )。从LSTM看过来,接触到了一些相关的程序,后面再把bucketing_module.py那部分查看了下,发现bucket只是一个应用层机制,主要的实现存在于module/bucketing_module.py里面。原理清晰,实现简洁,在这做个记号。

Code & Comments

先放些相关的链接,做个预备。

  1. MXNet 官方的文档(\tucao 出个文档真不容易,还带时效性...)
  2. 大神的blog阐述,鞭辟入里
  3. 之前关于LSTM的blog
    鉴于大神已经在这篇[blog]里面说得生动透彻了,这里就能省就省,然后说些大神没功夫顾及的细节。
    另外考虑到MXNet的链接经常表现出不靠谱的症状(\kuxia),归结一下1中有些用的结论:要使用bucket机制,初始化Module时传入的symbol应该是一个函数,这个函数在被调用时将被传入迭代器中的bucket_key参数

从调用路径的顺序来走一遍把。
fit里面经过bind,init等操作,后面会调用prepare对预取出的数据(如果有)进行准备:

# module/bucketing_module.py ???def prepare(self, data_batch): ???????"""Prepares a data batch for forward. ???????Parameters ???????---------- ???????data_batch : DataBatch ???????""" ???????# perform bind if haven‘t done so ???????assert self.binded and self.params_initialized ???????bucket_key = data_batch.bucket_key ???????original_bucket_key = self._curr_bucket_key ???????data_shapes = data_batch.provide_data ???????label_shapes = data_batch.provide_label ???????self.switch_bucket(bucket_key, data_shapes, label_shapes) ???????# switch back ???????self.switch_bucket(original_bucket_key, None, None)

显然,switch_bucket就是负责进行重新绑定的:

# module/bucketing_module.py ???def switch_bucket(self, bucket_key, data_shapes, label_shapes=None): ????????assert self.binded, ‘call bind before switching bucket‘ ???????if not bucket_key in self._buckets: ???# check if there is already... ???????????symbol, data_names, label_names = self._sym_gen(bucket_key) ???????????module = Module(symbol, data_names, label_names, ???????????????????????????logger=self.logger, context=self._context, ???????????????????????????work_load_list=self._work_load_list, ???????????????????????????fixed_param_names=self._fixed_param_names, ???????????????????????????state_names=self._state_names) ???????????module.bind(data_shapes, label_shapes, self._curr_module.for_training, ???????????????????????self._curr_module.inputs_need_grad, ???????????????????????force_rebind=False, shared_module=self._buckets[self._default_bucket_key]) ???????????self._buckets[bucket_key] = module ???????self._curr_module = self._buckets[bucket_key] ???????self._curr_bucket_key = bucket_key

逻辑很明白,_curr_module里面放了众多的module,这些module的参数全都指向同一组。如果出入的bucket_key没有出现过,就bind一个并放入*_curr_module列表里面去;如果已经有了(包括刚刚bind出来的),就切换到那个module*上。

Misc

其他有一些相关的材料顺带放在这。

  1. 上一篇blog里面推测bucket机制可能会对补齐的那部分进行处理,这一点与io.py里面的DataBatchpad变量有些联系。在module/base_module.py中,查找pad的引用,发现和io.py里面的注释一致,只在prediction的时候进行了使用,训练的时候被忽视。
  2. exmple/rnn/bucketing里面有更高层接口的使用示例。

MXNet中bucket机制注记

原文地址:http://www.cnblogs.com/chenyliang/p/8060014.html

知识推荐

我的编程学习网——分享web前端后端开发技术知识。 垃圾信息处理邮箱 tousu563@163.com 网站地图
icp备案号 闽ICP备2023006418号-8 不良信息举报平台 互联网安全管理备案 Copyright 2023 www.wodecom.cn All Rights Reserved