From e8cd4fcbb293226b903477087323bf5ea6f0e5d8 Mon Sep 17 00:00:00 2001 From: jcmgray Date: Tue, 16 Oct 2018 22:18:48 +0100 Subject: [PATCH] torch: use tensordot if available --- opt_einsum/backends/torch.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/opt_einsum/backends/torch.py b/opt_einsum/backends/torch.py index 90edec1..47e2216 100644 --- a/opt_einsum/backends/torch.py +++ b/opt_einsum/backends/torch.py @@ -12,17 +12,20 @@ __all__ = ["transpose", "einsum", "tensordot", "to_torch", "build_expression", "evaluate_constants"] _TORCH_DEVICE = None +_TORCH_HAS_TENSORDOT = None _torch_symbols_base = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ' def _get_torch_and_device(): global _TORCH_DEVICE + global _TORCH_HAS_TENSORDOT if _TORCH_DEVICE is None: import torch device = 'cuda' if torch.cuda.is_available() else 'cpu' _TORCH_DEVICE = torch, device + _TORCH_HAS_TENSORDOT = hasattr(torch, 'tensordot') return _TORCH_DEVICE @@ -47,7 +50,11 @@ def einsum(equation, *operands): def tensordot(x, y, axes=2): """Simple translation of tensordot syntax to einsum. """ - # XXX: tensordot should be directly implemented in torch soon + torch, _ = _get_torch_and_device() + + if _TORCH_HAS_TENSORDOT: + return torch.tensordot(x, y, dims=axes) + xnd = x.ndimension() ynd = y.ndimension()