0 Pluspunkte 0 Minuspunkte

Wie kann ich 2 pyTorch Tensoren vergleichen?

a = torch.tensor([1, 2, 3, 4]); 
b = torch.tensor([1, 2, 3, 4]);

if torch.eq(a, b):
    print("Beide Tensoren sind gleich")
von  

1 Antwort

0 Pluspunkte 0 Minuspunkte

Du versuchst die Methode torch.eq() direkt auf den Tensoren a und b aufzurufen, was zu einer mehrdeutigen Auswertung (für jeden Index) führt. Du kannst die torch.eq() Funktion mit der torch.all() Funktion kombinieren, um sicherzustellen, dass alle Elemente der Tensoren gleich sind.

import torch

a = torch.tensor([1, 2, 3, 4])
b = torch.tensor([1, 2, 3, 4])

if torch.all(torch.eq(a, b)):
    print("Beide Tensoren sind gleich")

von