learn2learn: A PyTorch Meta-Learning Framework for Researchers

Reader submission, 2022-10-27


learn2learn is a PyTorch library for meta-learning implementations.

The goal of meta-learning is to enable agents to learn how to learn. That is, we would like our agents to become better learners as they solve more and more tasks. For example, the original README shows an animation of an agent that learns to run after only one parameter update.

Features

learn2learn provides high- and low-level utilities for meta-learning. The high-level utilities let any user take advantage of existing meta-learning algorithms. The low-level utilities enable researchers to develop new and better meta-learning algorithms.

Some features of learn2learn include:

- Modular API: implement your own training loops with our low-level utilities.
- Various meta-learning algorithms (e.g. MAML, FOMAML, MetaSGD, ProtoNets, DiCE).
- A task generator with a unified API, compatible with torchvision, torchtext, torchaudio, and cherry.
- Standardized meta-learning tasks for vision (Omniglot, mini-ImageNet), reinforcement learning (Particles, Mujoco), and even text (news classification).
- 100% compatible with PyTorch: use your own modules, datasets, or libraries!
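As a rough, library-free illustration of what the task generator produces, the sketch below draws one N-way, K-shot task from a list of labels: it first picks n classes, then k examples of each. `sample_task` is a hypothetical helper, not part of learn2learn, and the real task transforms may differ in details such as sampling with replacement or label remapping.

```python
import random
from collections import defaultdict


def sample_task(labels, n_ways, k_shots, rng=random):
    """Return example indices for one N-way, K-shot task.

    labels: list where labels[i] is the class of example i.
    """
    # Group example indices by class
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # Only classes with at least k examples can contribute k shots
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= k_shots]
    if len(eligible) < n_ways:
        raise ValueError("not enough classes with k_shots examples")

    classes = rng.sample(eligible, n_ways)  # N-ways step: choose the classes
    task = []
    for c in classes:
        task.extend(rng.sample(by_class[c], k_shots))  # K-shots step: choose examples per class
    return task
```

For instance, with `labels = [0, 0, 1, 1, 2, 2, 3, 3]`, `sample_task(labels, n_ways=3, k_shots=2)` returns six distinct indices covering three of the four classes, two per class.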

Installation

pip install learn2learn

API Demo

The following is an example of using the high-level MAML implementation on MNIST. For more algorithms and lower-level utilities, please refer to the documentation or the examples.

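import torchvision
from torch import nn, optim
import learn2learn as l2l
from learn2learn.data.transforms import NWays, KShots, LoadData

mnist = torchvision.datasets.MNIST(root="/tmp/mnist", train=True, download=True,
                                   transform=torchvision.transforms.ToTensor())
mnist = l2l.data.MetaDataset(mnist)
train_tasks = l2l.data.TaskDataset(mnist,
                                   task_transforms=[
                                       NWays(mnist, n=3),
                                       KShots(mnist, k=2),  # 2 shots per class: one to adapt, one to evaluate
                                       LoadData(mnist),
                                   ],
                                   num_tasks=10)

# Placeholder model: classifies all 10 digits, so raw MNIST labels work as targets
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 10))
maml = l2l.algorithms.MAML(model, lr=1e-3, first_order=False)
opt = optim.Adam(maml.parameters(), lr=4e-3)
loss_fn = nn.CrossEntropyLoss()
num_iterations, tasks_per_iteration, adaptation_steps = 100, 10, 1

for iteration in range(num_iterations):
    for _ in range(tasks_per_iteration):
        learner = maml.clone()  # Creates a fresh differentiable clone of the model for each task
        data, labels = train_tasks.sample()
        # Split the task into adaptation and evaluation samples
        adapt_x, adapt_y = data[::2], labels[::2]
        eval_x, eval_y = data[1::2], labels[1::2]
        # Fast adapt
        for step in range(adaptation_steps):
            error = loss_fn(learner(adapt_x), adapt_y)
            learner.adapt(error)
        # Compute evaluation loss
        evaluation_error = loss_fn(learner(eval_x), eval_y)
        # Meta-update the model parameters
        opt.zero_grad()
        evaluation_error.backward()
        opt.step()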
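The demo hides the meta-gradient inside `learner.adapt(error)` and `evaluation_error.backward()`. The sketch below spells it out for a toy problem, assuming a scalar model y = w·x fitted with squared error so that every derivative can be written by hand instead of using autograd. The function names are hypothetical and this is not how learn2learn is implemented. With `first_order=False` the gradient is chained through each inner update (a factor of 1 - α·L''(w) per step); with `first_order=True` that factor is dropped, which is the first-order (FOMAML) approximation.

```python
def mse_and_grad(w, a, xs):
    """Squared error of the toy model y = w*x against targets y = a*x.

    Returns (loss, dloss/dw, d2loss/dw2); the loss is mean(x^2) * (w - a)^2.
    """
    m = sum(x * x for x in xs) / len(xs)
    return m * (w - a) ** 2, 2 * m * (w - a), 2 * m


def maml_meta_grad(w, task, inner_lr, adaptation_steps=1, first_order=False):
    """Post-adaptation evaluation loss and its derivative w.r.t. the meta-parameter w.

    task = (a, adapt_xs, eval_xs): true slope, adaptation inputs, evaluation inputs.
    """
    a, adapt_xs, eval_xs = task
    w_fast, dfast_dw = w, 1.0  # adapted weight and d(w_fast)/dw
    for _ in range(adaptation_steps):
        _, g, h = mse_and_grad(w_fast, a, adapt_xs)
        w_fast = w_fast - inner_lr * g  # one inner gradient step ("adapt")
        dfast_dw *= 1 - inner_lr * h    # chain rule through that step
    eval_loss, g_eval, _ = mse_and_grad(w_fast, a, eval_xs)
    if first_order:
        return eval_loss, g_eval        # treat d(w_fast)/dw as 1
    return eval_loss, g_eval * dfast_dw


def meta_train(w, tasks, inner_lr=0.1, meta_lr=0.05, iterations=200, first_order=False):
    """Plain gradient descent on the average post-adaptation loss over tasks."""
    for _ in range(iterations):
        grads = [maml_meta_grad(w, t, inner_lr, first_order=first_order)[1] for t in tasks]
        w -= meta_lr * sum(grads) / len(grads)
    return w
```

For two tasks with slopes 1 and 3 and identical inputs, `meta_train` moves w toward 2, the point from which one adaptation step serves both tasks equally. The exact and first-order variants both get there, but at different speeds, because their gradients differ by the chain-rule factor described above.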

Changelog

A human-readable changelog is available in the CHANGELOG.md file.

Documentation

Documentation and tutorials are available on learn2learn’s website: http://learn2learn-.

Citation

To cite the learn2learn repository in your academic publications, please use the following reference.

Sebastien M.R. Arnold, Praateek Mahajan, Debajyoti Datta, Ian Bunner. "learn2learn". https://github.com/learnables/learn2learn, 2019.

You can also use the following Bibtex entry.

@misc{learn2learn2019,
    author = {Sebastien M.R. Arnold and Praateek Mahajan and Debajyoti Datta and Ian Bunner},
    title = {learn2learn},
    month = sep,
    year = 2019,
    url = {https://github.com/learnables/learn2learn}
}

Acknowledgements & Friends

- The RL environments are adapted from Tristan Deleu's implementations and from the ProMP repository. Both are shared with permission, under the MIT License.
- TorchMeta is a similar library, with a focus on supervised meta-learning. If learn2learn were missing a particular functionality, we would go check if TorchMeta has it. But we would also open an issue ;)
- higher is a PyTorch library that also enables differentiating through optimization inner loops. Its approach differs from learn2learn's in that it monkey-patches nn.Module to be stateless. For more information, refer to their arXiv paper.

