RuntimeError: expected device cuda:0 and dtype Byte but got device cuda:0 and dtype Bool
pytorch 运行的时候出现了下面的错误:
loading 9423 train samples...loading 1048 dev samples...Traceback (most recent call last): File "train.py", line 170, in main() File "train.py", line 166, in main train_net(args) File "train.py", line 73, in train_net logger=logger) File "train.py", line 115, in train pred, gold = model(padded_input, input_lengths, padded_target) File "/home/eric/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__ result = self.forward(*input, **kwargs) File "/media/data/sinovation_ventures/speechTransformer/transformer/transformer.py", line 35, in forward input_lengths) File "/home/eric/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__ result = self.forward(*input, **kwargs) File "/media/data/sinovation_ventures/speechTransformer/transformer/decoder.py", line 98, in forward slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)RuntimeError: expected device cuda:0 and dtype Byte but got device cuda:0 and dtype Bool
我的torch版本是1.2
解决方法
pip install torch==1.4
参考文献
[1].fix runtime error with pytorch 1.2.0. https://github.com/jadore801120/attention-is-all-you-need-pytorch/pull/115
版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。
暂时没有评论,来抢沙发吧~