如何提升企业数字化转型的效率与灵活性
2066
2022-11-17
python的jax包的常用操作
python的jax包的常用操作
本文参考官方文档
1.jax.random包
PRNGKey
>>> from jax import random>>> key = random.PRNGKey(0)>>> keyDeviceArray([0, 0], dtype=uint32)
根据传入参数,生成两个无符号32位整数(不用管具体的细节,理解后面的使用即可),我们通常其称为key
>>> random.uniform(key)DeviceArray(0.41845703, dtype=float32)
key可以用于任何jax的随机数生成
>>> random.uniform(key)DeviceArray(0.41845703, dtype=float32)
如果key不变,结果不变 如果你需要新的随机数,你可以使用jax.random.split()
>>> key, subkey = random.split(key)>>> random.uniform(subkey)DeviceArray(0.10536897, dtype=float32)
jax.random.gumbel
Gumbel 分布及应用浅析
#指定形状和数据类型的gumbel采样,key为刚才的key,返回的是个数组类型jax.random.gumbel(key, shape=(), dtype=
2.jax.experimental.stax包
jax.experimental.stax是一个小型但灵活的神经网络规范库,可以快速的生成指定的网络层
from jax.experimental import staxlayers=[]#存放生成的网络层
#生成out_dim个神经元的全连接层stax.Dense(out_dim, W_init=
stax.serial
#返回将网络层组合起来的结果,是个(init_fun, apply_fun) pair,表示给定层序列的连续组成。layers是刚才定义的list,里面放刚才定义的层即可stax.serial(*layers)#我们甚至可以用新的list来接受stax.serial(*layers)返回值进行套娃layers=[stax.serial(*layers)]layers.append(stax.Dense(16))layers.append(stax.Relu)
3.jax.numpy包
这个包方法和numpy类似
#创建一个数组jax.numpy.array(object, dtype=None, copy=True, order='K', ndmin=0)#对x进行深度为k的one_hot编码jax.numpy.eye(k)[x]#返回将对数组进行排序的索引,a为数组,axis为指定维度jax.numpy.argsort(a, axis=- 1, kind='quicksort', order=None)#计算指定维度的均值jax.numpy.mean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=None)#计算指定数组维度的和jax.numpy.sum(a, axis=None, dtype=None, out=None, keepdims=None, initial=None, where=None)#重复一个数组的元素jax.numpy.repeat(a, repeats, axis=None, *, total_repeat_length=None)
版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。
发表评论
暂时没有评论,来抢沙发吧~