LibTorch之损失函数
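在 LibTorch（PyTorch 的 C++ 前端）中，分类任务常用负对数似然损失 `torch::nll_loss`。它的输入应是对数概率（例如 `torch::log_softmax` 的输出），目标是类别索引。基本调用形式如下：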
torch::Tensor loss = torch::nll_loss(prediction, batch.target);
以下是官方示例（PyTorch C++ 前端文档中的 MNIST 训练示例）中的完整用法：
#include // Use one of many "standard library" modules.torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr}; // Define a new Module.struct Net : torch::nn::Module { Net() { // Construct and register two Linear submodules. fc1 = register_module("fc1", torch::nn::Linear(784, 64)); fc2 = register_module("fc2", torch::nn::Linear(64, 32)); fc3 = register_module("fc3", torch::nn::Linear(32, 10)); } // Implement the Net's algorithm. torch::Tensor forward(torch::Tensor x) { // Use one of many tensor manipulation functions. x = torch::relu(fc1->forward(x.reshape({x.size(0), 784}))); x = torch::dropout(x, /*p=*/0.5, /*train=*/is_training()); x = torch::relu(fc2->forward(x)); x = torch::log_softmax(fc3->forward(x), /*dim=*/1); return x; }};int main() { // Create a new Net. auto net = std::make_shared(); // Create a multi-threaded data loader for the MNIST dataset. auto data_loader = torch::data::make_data_loader( torch::data::datasets::MNIST("./data").map( torch::data::transforms::Stack<>()), /*batch_size=*/64); // Instantiate an SGD optimization algorithm to update our Net's parameters. torch::optim::SGD optimizer(net->parameters(), /*lr=*/0.01); for (size_t epoch = 1; epoch <= 10; ++epoch) { size_t batch_index = 0; // Iterate the data loader to yield batches from the dataset. for (auto& batch : *data_loader) { // Reset gradients. optimizer.zero_grad(); // Execute the model on the input data. torch::Tensor prediction = net->forward(batch.data); // Compute a loss value to judge the prediction of our model. torch::Tensor loss = torch::nll_loss(prediction, batch.target); // Compute gradients of the loss w.r.t. the parameters of our model. loss.backward(); // Update the parameters based on the calculated gradients. optimizer.step(); // Output the loss and checkpoint every 100 batches. 
if (++batch_index % 100 == 0) { std::cout << "Epoch: " << epoch << " | Batch: " << batch_index << " | Loss: " << loss.item() << std::endl; // Serialize your model periodically as a checkpoint. torch::save(net, "net.pt"); } } }}
版权声明：本文内容由网络用户投稿，版权归原作者所有，本站不拥有其著作权，亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容，请联系我们 jiasou666@gmail.com 处理，核实后本网站将在24小时内删除侵权内容。
暂时没有评论,来抢沙发吧~