PyTorchのtorch.max()の使い方はどのようになりますか?
この記事では、PyTorchのtorch.max()関数の使い方について見ていきます。
予想通り、これは非常にシンプルな機能ですが、驚くほど多様な要素を持っています。
この関数を使って、いくつかの簡単な例を見てみましょう。
書いている時点では、使用されているPyTorchのバージョンはPyTorch 1.5.0です。
PyTorchのtorch.max()の基本的な構文については、以下のようになります。
PyTorchのtorch.max()を使用するためには、まずtorchをインポートします。
import torch
今、この関数はテンソル内の要素のうち最大値を返します。
PyTorchのtorch.max()のデフォルトの動作
デフォルトの動作は、グローバルな最大要素に対応する単一の要素とインデックスを返すことです。
max_element = torch.max(input_tensor)
以下は例です。 (Kore wa rei desu.)
p = torch.randn([2, 3])
print(p)
max_element = torch.max(p)
print(max_element)
出力
tensor([[-0.0665, 2.7976, 0.9753],
[ 0.0688, -1.0376, 1.4443]])
tensor(2.7976)
実際に、これによってテンソル内の最大要素を取得できます。
次元を指定してtorch.max()を使用してください。
しかし、単一の要素ではなく、テンソルとして特定の次元に沿って最大値を取得することを希望する場合があります。
次元を指定するためのキーワード引数(numpyの軸)には、`dim`という別のオプションがあります。
これは最大限に取る方向を表しています。
これは、max_elementsとmax_indicesを含むタプルを返します。
- max_elements -> All the maximum elements of the Tensor.
- max_indices -> Indices corresponding to the maximum elements.
max_elements, max_indices = torch.max(input_tensor, dim)
これは、次元dimに沿って最大の要素を持つTensorを返します。
では、いくつかの例を見てみましょう。
p = torch.randn([2, 3])
print(p)
# Get the maximum along dim = 0 (axis = 0)
max_elements, max_idxs = torch.max(p, dim=0)
print(max_elements)
print(max_idxs)
出力
tensor([[-0.0665, 2.7976, 0.9753],
[ 0.0688, -1.0376, 1.4443]])
tensor([0.0688, 2.7976, 1.4443])
tensor([1, 0, 1])
見てわかる通り、私たちは次元0(列方向の最大値)に沿って最大値を見つけます。
また、要素に対応するインデックスを取得します。例えば、0.0688は0列のインデックス1を持ちます。
同様に、行ごとの最大値を見つけたい場合は dim=1 を使用します。
# Get the maximum along dim = 1 (axis = 1)
max_elements, max_idxs = torch.max(p, dim=1)
print(max_elements)
print(max_idxs)
出力
tensor([2.7976, 1.4443])
tensor([1, 2])
実際に、私たちは行方向において最大値の要素とそのインデックスを取得します。
比較のためにtorch.max()を使用します。
私たちは、torch.max()を使用して、2つのテンソル間の最大値を取得することもできます。
output_tensor = torch.max(a, b)
ここで、aとbは同じ次元を持っている必要があります。または、「broadcastable(ブロードキャスト可能)」なテンソルでなければなりません。
同じ次元を持つ2つのテンソルを比較する簡単な例を示します。
p = torch.randn([2, 3])
q = torch.randn([2, 3])
print("p =", p)
print("q =",q)
# Compare elements of p and q and get the maximum
max_elements = torch.max(p, q)
print(max_elements)
出力
p = tensor([[-0.0665, 2.7976, 0.9753],
[ 0.0688, -1.0376, 1.4443]])
q = tensor([[-0.0678, 0.2042, 0.8254],
[-0.1530, 0.0581, -0.3694]])
tensor([[-0.0665, 2.7976, 0.9753],
[ 0.0688, 0.0581, 1.4443]])
確かに、pとqの間に最大の要素を持つ出力テンソルを得ます。
結論
この記事では、torch.max()関数を使用して、Tensorの最大要素を見つける方法について学びました。
私たちはこの機能を使って2つのテンソルを比較し、その中で最大値を取得しました。
類似の記事に関しては、ぜひ当社のPyTorchチュートリアルのコンテンツをご覧ください!さらにたくさんの情報をお楽しみに!
参考文献を日本語で書き換えてください。オプションは1つだけで十分です。
- PyTorch Official Documentation on torch.max()