Building Network Models with LibTorch


Linear layer

torch::nn::Linear ln1{nullptr};
ln1 = register_module("ln", torch::nn::Linear(torch::nn::LinearOptions(24, 1)));
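On its own this line will not compile: register_module must be called from inside a torch::nn::Module subclass. A minimal sketch of how it is typically used (the struct name LinearNet and the forward pass are illustrative additions; the 24-in / 1-out sizes come from the line above):

#include <torch/torch.h>

// Minimal sketch: declare the linear layer as a member, register it in the
// constructor, and apply it in forward().
struct LinearNet : torch::nn::Module {
  LinearNet() {
    ln1 = register_module("ln", torch::nn::Linear(torch::nn::LinearOptions(24, 1)));
  }

  torch::Tensor forward(torch::Tensor x) {
    // Expects x of shape (N, 24); returns shape (N, 1).
    return ln1->forward(x);
  }

  torch::nn::Linear ln1{nullptr};
};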

BatchNorm1d layer

torch::nn::BatchNorm1d bn1{nullptr};
bn1 = register_module("bn", torch::nn::BatchNorm1d(10));
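The same registration pattern applies. BatchNorm1d(10) normalizes over a feature dimension of size 10, so it expects input of shape (N, 10) or (N, 10, L), and it uses batch statistics in training mode and running statistics in eval mode. A minimal sketch, assuming the layer lives inside a Module subclass (the struct name BnNet and the forward pass are illustrative):

struct BnNet : torch::nn::Module {
  BnNet() {
    bn1 = register_module("bn", torch::nn::BatchNorm1d(10));
  }

  torch::Tensor forward(torch::Tensor x) {
    // x: (N, 10) or (N, 10, L); behavior depends on is_training().
    return bn1->forward(x);
  }

  torch::nn::BatchNorm1d bn1{nullptr};
};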

Official example

#include <torch/torch.h>

// Define a new Module.
struct Net : torch::nn::Module {
  Net() {
    // Construct and register three Linear submodules.
    fc1 = register_module("fc1", torch::nn::Linear(784, 64));
    fc2 = register_module("fc2", torch::nn::Linear(64, 32));
    fc3 = register_module("fc3", torch::nn::Linear(32, 10));
  }

  // Implement the Net's algorithm.
  torch::Tensor forward(torch::Tensor x) {
    // Use one of many tensor manipulation functions.
    x = torch::relu(fc1->forward(x.reshape({x.size(0), 784})));
    x = torch::dropout(x, /*p=*/0.5, /*train=*/is_training());
    x = torch::relu(fc2->forward(x));
    x = torch::log_softmax(fc3->forward(x), /*dim=*/1);
    return x;
  }

  // Use one of many "standard library" modules.
  torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr};
};

int main() {
  // Create a new Net.
  auto net = std::make_shared<Net>();

  // Create a multi-threaded data loader for the MNIST dataset.
  auto data_loader = torch::data::make_data_loader(
      torch::data::datasets::MNIST("./data").map(
          torch::data::transforms::Stack<>()),
      /*batch_size=*/64);

  // Instantiate an SGD optimization algorithm to update our Net's parameters.
  torch::optim::SGD optimizer(net->parameters(), /*lr=*/0.01);

  for (size_t epoch = 1; epoch <= 10; ++epoch) {
    size_t batch_index = 0;
    // Iterate the data loader to yield batches from the dataset.
    for (auto& batch : *data_loader) {
      // Reset gradients.
      optimizer.zero_grad();
      // Execute the model on the input data.
      torch::Tensor prediction = net->forward(batch.data);
      // Compute a loss value to judge the prediction of our model.
      torch::Tensor loss = torch::nll_loss(prediction, batch.target);
      // Compute gradients of the loss w.r.t. the parameters of our model.
      loss.backward();
      // Update the parameters based on the calculated gradients.
      optimizer.step();
      // Output the loss and checkpoint every 100 batches.
      if (++batch_index % 100 == 0) {
        std::cout << "Epoch: " << epoch << " | Batch: " << batch_index
                  << " | Loss: " << loss.item<float>() << std::endl;
        // Serialize your model periodically as a checkpoint.
        torch::save(net, "net.pt");
      }
    }
  }
}
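The example above checkpoints the model with torch::save every 100 batches. A hedged sketch of how such a checkpoint could later be restored for inference (this part is not in the official example; the file name, input shape, and use of random data are assumptions for illustration):

// Reload the serialized parameters into a freshly constructed Net
// and switch to evaluation mode before running inference.
auto net = std::make_shared<Net>();
torch::load(net, "net.pt");
net->eval();
torch::NoGradGuard no_grad;                    // disable gradient tracking
torch::Tensor input = torch::randn({1, 784});  // one flattened 28x28 image
torch::Tensor output = net->forward(input);    // log-probabilities, shape (1, 10)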

#include <torch/torch.h>

// Define the model.
struct Net : torch::nn::Module {
  Net() {
    // Construct and register three Linear submodules.
    fc1 = register_module("fc1", torch::nn::Linear(784, 64));
    fc2 = register_module("fc2", torch::nn::Linear(64, 32));
    fc3 = register_module("fc3", torch::nn::Linear(32, 10));
  }

  // Implement the network's forward pass.
  torch::Tensor forward(torch::Tensor x) {
    x = torch::relu(fc1->forward(x.reshape({x.size(0), 784})));
    x = torch::relu(fc2->forward(x));
    x = torch::log_softmax(fc3->forward(x), /*dim=*/1);
    return x;
  }

  torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr};
};

int main(int argc, char** argv) {
  return 0;
}
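To sanity-check the skeleton, main could construct the model and run a dummy forward pass. A minimal sketch (the random input and printed shape are illustrative, not part of the original post):

int main(int argc, char** argv) {
  auto net = std::make_shared<Net>();
  torch::Tensor x = torch::randn({4, 784});  // batch of 4 flattened 28x28 inputs
  torch::Tensor y = net->forward(x);
  std::cout << y.sizes() << std::endl;       // expected: [4, 10]
  return 0;
}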
