How do we use nn.Parameter in PyTorch?

In PyTorch, nn.Parameter is a subclass of Tensor that represents a trainable parameter of an nn.Module. When an nn.Parameter is assigned as an attribute of a module, the module automatically registers it as one of the model's trainable parameters. The registration happens at assignment time, not in the nn.Module constructor.

To use nn.Parameter in PyTorch, you first need to create an nn.Parameter object and assign it as an attribute of your model. Here is a simple example:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.rand(3, 4))  # create a trainable parameter

    def forward(self, x):
        out = torch.matmul(x, self.weight)
        return out

model = MyModel()
print(model.weight)  # print the parameter

In the example above, we defined a class called MyModel that inherits from nn.Module. In the constructor __init__, we created an nn.Parameter object self.weight, which wraps a Tensor of shape (3, 4) whose values torch.rand draws uniformly from [0, 1).
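As a quick check of what the constructor creates, the sketch below builds a standalone parameter the same way and inspects it. The variable name weight is only for illustration.

```python
import torch
import torch.nn as nn

# A standalone parameter, built the same way as self.weight in the example.
weight = nn.Parameter(torch.rand(3, 4))

print(isinstance(weight, torch.Tensor))  # nn.Parameter is a Tensor subclass
print(weight.shape)                      # torch.Size([3, 4])
print(weight.requires_grad)              # gradient tracking is on by default
```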

In the forward method, the input x is multiplied by self.weight with torch.matmul, so the last dimension of x must be 3 and the last dimension of the output is 4. Once the model is created, the parameter can be accessed as model.weight.
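To make the shapes concrete, here is a minimal sketch that runs the model on a small input. The batch size of 2 is an arbitrary choice for illustration.

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.rand(3, 4))

    def forward(self, x):
        return torch.matmul(x, self.weight)

model = MyModel()
x = torch.rand(2, 3)   # hypothetical batch: 2 samples with 3 features each
out = model(x)         # calling the model invokes forward()
print(out.shape)       # torch.Size([2, 4])
```

An input whose last dimension is not 3 would make torch.matmul raise a shape mismatch error.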

It is important to note that an nn.Parameter assigned as a module attribute is automatically registered as a trainable parameter of the model and is returned by the model’s parameters() method, which is what you would typically pass to an optimizer. An nn.Parameter also has requires_grad=True by default, so calling backward() on a scalar computed from the model’s output stores the gradient in the parameter’s .grad attribute.
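The sketch below illustrates both points under a few assumptions: the extra scale attribute, the sum() used as a stand-in loss, and the SGD learning rate are hypothetical and not part of the original example. A plain tensor attribute is included to show that only nn.Parameter attributes are registered.

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.rand(3, 4))  # registered as a parameter
        self.scale = torch.ones(4)                    # plain tensor (hypothetical): not registered

    def forward(self, x):
        return torch.matmul(x, self.weight) * self.scale

model = MyModel()

# Only the nn.Parameter attribute shows up among the model's parameters.
names = [name for name, _ in model.named_parameters()]
print(names)  # ['weight']

# Reduce the output to a scalar (a stand-in for a real loss) and backpropagate.
x = torch.rand(2, 3)
loss = model(x).sum()
loss.backward()
print(model.weight.grad.shape)  # torch.Size([3, 4])

# An optimizer built from model.parameters() updates the registered parameter.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
before = model.weight.detach().clone()
optimizer.step()
```

After optimizer.step(), model.weight has moved from before by the learning rate times the gradient, while self.scale is left untouched because the optimizer never sees it.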
