Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

torch backend: use tensordot if available #66

Merged
merged 1 commit into from
Oct 16, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion opt_einsum/backends/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,20 @@
__all__ = ["transpose", "einsum", "tensordot", "to_torch", "build_expression", "evaluate_constants"]

_TORCH_DEVICE = None
_TORCH_HAS_TENSORDOT = None

_torch_symbols_base = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'


def _get_torch_and_device():
global _TORCH_DEVICE
global _TORCH_HAS_TENSORDOT

if _TORCH_DEVICE is None:
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
_TORCH_DEVICE = torch, device
_TORCH_HAS_TENSORDOT = hasattr(torch, 'tensordot')

return _TORCH_DEVICE

Expand All @@ -47,7 +50,11 @@ def einsum(equation, *operands):
def tensordot(x, y, axes=2):
"""Simple translation of tensordot syntax to einsum.
"""
# XXX: tensordot should be directly implemented in torch soon
torch, _ = _get_torch_and_device()

if _TORCH_HAS_TENSORDOT:
return torch.tensordot(x, y, dims=axes)

xnd = x.ndimension()
ynd = y.ndimension()

Expand Down