-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathlenet5.h
More file actions
25 lines (20 loc) · 752 Bytes
/
lenet5.h
File metadata and controls
25 lines (20 loc) · 752 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#pragma once
#include <torch/torch.h>
/// @brief Implementation of the LeNet-5 convolutional network,
/// according to LeCun Y, Bottou L, Bengio Y, et al. Gradient-based learning applied to document recognition[J]. Proceedings of the IEEE, 1998, 86(11): 2278-2324.
///
/// The layer members below mirror the layer labels used in the cited paper
/// (C1, C3, C5 convolutions; F6 and OUTPUT fully connected layers).
class LeNet5Impl : public torch::nn::Module
{
public:
    /// @brief Constructs the network and its layers.
    /// @param input_size width of the input image. Per the original note, inputs are
    ///        padded to 32 * 32. NOTE(review): confirm against the constructor
    ///        definition how sizes other than 32 are handled.
    /// Marked explicit so that a plain int is never silently converted into a network.
    explicit LeNet5Impl(int input_size);

    /// @brief Runs the forward pass of the network.
    /// @param x input tensor. NOTE(review): presumably (batch, channels, H, W) images;
    ///        confirm the expected shape against the forward() definition.
    /// @return output tensor. Presumably produced by the OUTPUT layer; confirm in the definition.
    torch::Tensor forward(torch::Tensor x);

private:
    // Convolution layers C1, C3 and C5. They are left null here and are expected to be
    // created and registered by the constructor.
    // Per the original note, the pooling layers S2 and S4 are applied inside forward()
    // instead of being stored as members, so they have no fields here.
    torch::nn::Conv2d C1{nullptr}, C3{nullptr}, C5{nullptr};
    // Fully connected layers F6 and OUTPUT. They are likewise null until the constructor runs.
    torch::nn::Linear F6{nullptr}, OUTPUT{nullptr};
};
// Declares the `LeNet5` module holder (shared-pointer wrapper) around LeNet5Impl.
TORCH_MODULE(LeNet5);