Merge pull request #3 from Orcuslc/dev
Dev
Orcuslc authored Apr 19, 2018
2 parents a72be22 + b8acfc7 commit 66377a7
Showing 26 changed files with 570 additions and 613 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -4,6 +4,7 @@ test/
dist/
MANIFEST
old_orthnet/
orthnet/utils/_enum_dim.cpython*
build/
orthnet/utils/_enum_dim.*
orthnet/utils/enum_dim/enum_dim.py
orthnet/utils/enum_dim/enum_dim_wrap.cpp
66 changes: 50 additions & 16 deletions README.md
@@ -36,20 +36,54 @@ python3 setup.py build_ext --inplace && python3 setup.py install
- orthnet.Jacobi(Poly)

## Base class:
Class `Poly(module, degree, x, dtype = 'float32', loglevel = 0)`:

Class `Poly(x, degree, combination = None)`:
- Inputs:
    + `module`: one of {'tensorflow', 'pytorch', 'numpy'}
    + `degree`: the highest degree of the target polynomial
    + `x`: input tensor of type {tf.placeholder, tf.Variable, torch.Variable, torch.Tensor, numpy.ndarray, numpy.matrix}
    + `dtype`: 'float32' or 'float64'
    + `loglevel`: 1 to print the time cost, 0 to mute

- `Poly.tensor`: return a tensor of function values
- `Poly.combination`: return the combination of dimensions, in lexicographical order
- `Poly.index`: return the index of the first combination of each degree in `self.combination`
- `Poly.update(degree)`: update the degree of polynomial
- `Poly.get_combination(start, end)`: return the combinations of degrees from `start` (inclusive) to `end` (inclusive)
- `Poly.get_poly(start, end)`: return the polynomials of degrees from `start` (inclusive) to `end` (inclusive)
- `Poly.eval(coefficients)`: evaluate the polynomial with the given coefficients
- `Poly.quadrature(func, weight)`: evaluate the Gauss quadrature with the target function and weights
    + `x`: a tensor
    + `degree`: the highest degree of the target polynomials
    + `combination`: optional; if the combinations for a given degree and dimension have already been computed with `orthnet.enum_dim(degree, dim)`, they can be passed here to save computing time (see the sketch after the attribute list)
- Attributes:
    + `Poly.tensor`: the tensor of function values
    + `Poly.length`: the number of basis functions (columns) in `Poly.tensor`
    + `Poly.index`: the index of the first combination of each degree in `Poly.combinations`
    + `Poly.combinations`: all combinations of the tensor product
    + `Poly.tensor_by_degree(degree)`: the polynomials of the given degree(s)
    + `Poly.eval(coefficients)`: evaluate the function values with the given coefficients
    + `Poly.quadrature(func, weight)`: perform Gauss quadrature with the given function and weights
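
A minimal sketch of reusing precomputed combinations (an illustration, assuming the numpy backend and that the object returned by `orthnet.enum_dim(degree, dim)` is accepted directly by the `combination` argument):

```python
import numpy as np
from orthnet import Legendre, enum_dim

x = np.random.random((10, 3))

# compute the degree/dimension combinations once and reuse them
comb = enum_dim(5, 3)
L = Legendre(x, 5, combination = comb)

print(L.length)               # number of basis columns in L.tensor
print(L.tensor_by_degree(2))  # polynomials of degree 2 only
```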

## Examples:

### with TensorFlow
```python
import tensorflow as tf
import numpy as np
from orthnet import Legendre

x_data = np.random.random((10, 2))
x = tf.placeholder(dtype = tf.float32, shape = [None, 2])
L = Legendre(x, 5)

with tf.Session() as sess:
    print(sess.run(L.tensor, feed_dict = {x: x_data}))
```

### with PyTorch
```python
import torch
import numpy as np
from orthnet import Legendre

x = torch.DoubleTensor(np.random.random((10, 2)))
L = Legendre(x, 5)
print(L.tensor)
```

### with Numpy
```python
import numpy as np
from orthnet import Legendre

x = np.random.random((10, 2))
L = Legendre(x, 2)
print(L.tensor)
```
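
The following is a hedged sketch of `Poly.eval` (assuming `coefficients` holds one entry per basis column, i.e. `L.length` values):

```python
import numpy as np
from orthnet import Legendre

x = np.random.random((10, 2))
L = Legendre(x, 3)

# one coefficient per basis column of L.tensor
coefficients = np.random.random((L.length, 1))
values = L.eval(coefficients)   # function values of the expansion at the rows of x
print(values.shape)
```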

2 changes: 0 additions & 2 deletions demo/legendre.py
@@ -1,5 +1,3 @@
import sys
sys.path.append('../')
from orthnet import Legendre
import tensorflow as tf
import torch
3 changes: 2 additions & 1 deletion orthnet/__init__.py
@@ -1,3 +1,4 @@
from .poly import *
from .utils import enum_dim

__all__ = ['Legendre', 'Legendre_Normalized', 'Laguerre', 'Hermite', 'Hermite2', 'Chebyshev', 'Chebyshev2', 'Jacobi']
__all__ = ['Legendre', 'Legendre_Normalized', 'Laguerre', 'Hermite', 'Hermite2', 'Chebyshev', 'Chebyshev2', 'Jacobi', 'enum_dim']
5 changes: 5 additions & 0 deletions orthnet/backend/__init__.py
@@ -0,0 +1,5 @@
from ._tensorflow import TensorflowBackend
from ._torch import TorchBackend
from ._numpy import NumpyBackend

__all__ = ["TensorflowBackend", "TorchBackend", "NumpyBackend"]
41 changes: 41 additions & 0 deletions orthnet/backend/_backend.py
@@ -0,0 +1,41 @@
from functools import wraps

def assert_backend_available(f):
    @wraps(f)
    def check(backend, *args, **kw):
        if not backend.is_available():
            raise RuntimeError(
                "Backend `{}` is not available".format(str(backend)))
        return f(backend, *args, **kw)
    return check


class Backend(object):
    def __str__(self):
        return "<backend>"

    def __false(self):
        return False

    # by default a backend is neither available nor compatible;
    # concrete backends override these methods
    is_available = is_compatible = __false

    def concatenate(self, tensor, axis):
        return None

    def ones_like(self, tensor):
        return None

    def multiply(self, x, y):
        return None

    def expand_dims(self, tensor, axis):
        return None

    def get_dims(self, tensor):
        return None

    def reshape(self, tensor, shape):
        return None

    def matmul(self, tensor1, tensor2):
        return None
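
A small illustrative sketch (hypothetical, not part of the package) of what `assert_backend_available` guards: a backend whose `is_available()` returns `False` raises a clear `RuntimeError` from any decorated method:

```python
from orthnet.backend._backend import Backend, assert_backend_available

class DummyBackend(Backend):
    def __str__(self):
        return "dummy"

    def is_available(self):
        return False  # pretend the underlying library failed to import

    @assert_backend_available
    def is_compatible(self, args):
        return True

# raises RuntimeError: Backend `dummy` is not available
DummyBackend().is_compatible(None)
```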
49 changes: 49 additions & 0 deletions orthnet/backend/_numpy.py
@@ -0,0 +1,49 @@
"""
numpy backend
"""
try:
import numpy as np
except ImportError:
np = None

from ._backend import Backend, assert_backend_available


class NumpyBackend(Backend):

def __str__(self):
return "numpy"

def is_available(self):
return np is not None

@assert_backend_available
def is_compatible(self, args):
if list(filter(lambda t: isinstance(args, t), [
np.ndarray,
np.matrix
])) != []:
return True
# , "numpy backend requires input to be an instance of `np.ndarray` or `np.matrix`"
return False

def concatenate(self, tensor, axis):
return np.concatenate(tensor, axis = axis)

def ones_like(self, tensor):
return np.ones_like(tensor)

def multiply(self, x, y):
return x*y

def expand_dims(self, tensor, axis):
return np.expand_dims(tensor, axis)

def get_dims(self, tensor):
return tensor.shape

def reshape(self, tensor, shape):
return np.reshape(tensor, shape)

def matmul(self, tensor1, tensor2):
return np.dot(tensor1, tensor2)
49 changes: 49 additions & 0 deletions orthnet/backend/_tensorflow.py
@@ -0,0 +1,49 @@
"""
tensorflow backend
"""
try:
import tensorflow as tf
except ImportError:
tf = None

from ._backend import Backend, assert_backend_available


class TensorflowBackend(Backend):

def __str__(self):
return "tensorflow"

def is_available(self):
return tf is not None

@assert_backend_available
def is_compatible(self, args):
if list(filter(lambda t: isinstance(args, t), [
tf.Tensor,
tf.Variable
])) != []:
return True
# "tensorflow backend requires input to be an isinstance of `tensorflow.Tensor` or `tensorflow.Variable`"
return False

def concatenate(self, tensor, axis):
return tf.concat(tensor, axis = axis)

def ones_like(self, tensor):
return tf.ones_like(tensor)

def multiply(self, x, y):
return tf.multiply(x, y)

def expand_dims(self, tensor, axis):
return tf.expand_dims(tensor, axis)

def get_dims(self, tensor):
return [dim.value for dim in tensor.get_shape()]

def reshape(self, tensor, shape):
return tf.reshape(tensor, shape)

def matmul(self, tensor1, tensor2):
return tf.matmul(tensor1, tensor2)
51 changes: 51 additions & 0 deletions orthnet/backend/_torch.py
@@ -0,0 +1,51 @@
"""
torch backend
"""
try:
import torch
except ImportError:
torch = None

from ._backend import Backend, assert_backend_available


class TorchBackend(Backend):

def __str__(self):
return "torch"

def is_available(self):
return torch is not None

@assert_backend_available
def is_compatible(self, args):
if list(filter(lambda t: isinstance(args, t), [
torch.FloatTensor,
torch.DoubleTensor,
torch.cuda.FloatTensor,
torch.cuda.DoubleTensor
])) != []:
return True
# , "torch backend requires input to be an instance of `torch.FloatTensor`, `torch.DoubleTensor`, `torch.cuda.FloatTensor` or `torch.cuda.DoubleTensor`"
return False

def concatenate(self, tensor, axis):
return torch.cat(tensor, dim = axis)

def ones_like(self, tensor):
return torch.ones_like(tensor)

def multiply(self, x, y):
return torch.mul(x, y)

def expand_dims(self, tensor, axis):
return tensor.unsqueeze(axis)

def get_dims(self, tensor):
return tensor.size()

def reshape(self, tensor, shape):
return tensor.view(shape)

def matmul(self, tensor1, tensor2):
return torch.matmul(tensor1, tensor2)
10 changes: 5 additions & 5 deletions orthnet/poly/__init__.py
@@ -1,7 +1,7 @@
from .legendre import Legendre, Legendre_Normalized
from .laguerre import Laguerre
from .hermite import Hermite, Hermite2
from .chebyshev import Chebyshev, Chebyshev2
from .jacobi import Jacobi
from ._legendre import Legendre, Legendre_Normalized
from ._laguerre import Laguerre
from ._hermite import Hermite, Hermite2
from ._chebyshev import Chebyshev, Chebyshev2
from ._jacobi import Jacobi

__all__ = ['Legendre', 'Legendre_Normalized', 'Laguerre', 'Hermite', 'Hermite2', 'Chebyshev', 'Chebyshev2', 'Jacobi']
47 changes: 47 additions & 0 deletions orthnet/poly/_chebyshev.py
@@ -0,0 +1,47 @@
from ..backend import NumpyBackend, TensorflowBackend, TorchBackend

from .polynomial import Poly

class Chebyshev(Poly):
    """
    Chebyshev polynomials of the first kind
    """
    def __init__(self, x, degree, *args, **kw):
        """
        input:
            - x: a tensor
            - degree: highest degree of polynomial
        """
        self._all_backends = list(filter(lambda backend: backend.is_available(), [TensorflowBackend(), TorchBackend(), NumpyBackend()]))
        self._backend = None
        for backend in self._all_backends:
            if backend.is_compatible(x):
                self._backend = backend
                break
        if self._backend is None:
            raise TypeError("Cannot determine backend from input arguments of type `{}`. Available backends are: {}".format(type(x), ", ".join([str(backend) for backend in self._all_backends])))
        # T_0(x) = 1, T_1(x) = x; T_n(x) = 2*x*T_{n-1}(x) - T_{n-2}(x)
        initial = [lambda x: self._backend.ones_like(x), lambda x: x]
        recurrence = lambda p1, p2, n, x: self._backend.multiply(x, p1)*2 - p2
        Poly.__init__(self, self._backend, x, degree, initial, recurrence, *args, **kw)

class Chebyshev2(Poly):
    """
    Chebyshev polynomials of the second kind
    """
    def __init__(self, x, degree, *args, **kw):
        """
        input:
            - x: a tensor
            - degree: highest degree of polynomial
        """
        self._all_backends = list(filter(lambda backend: backend.is_available(), [TensorflowBackend(), TorchBackend(), NumpyBackend()]))
        self._backend = None
        for backend in self._all_backends:
            if backend.is_compatible(x):
                self._backend = backend
                break
        if self._backend is None:
            raise TypeError("Cannot determine backend from input arguments of type `{}`. Available backends are: {}".format(type(x), ", ".join([str(backend) for backend in self._all_backends])))
        # U_0(x) = 1, U_1(x) = 2x; U_n(x) = 2*x*U_{n-1}(x) - U_{n-2}(x)
        initial = [lambda x: self._backend.ones_like(x), lambda x: x*2]
        recurrence = lambda p1, p2, n, x: self._backend.multiply(x, p1)*2 - p2
        Poly.__init__(self, self._backend, x, degree, initial, recurrence, *args, **kw)
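
As a quick sanity check of the first-kind recurrence above (a sketch, assuming that for 1-D input the columns of `Chebyshev(x, degree).tensor` are ordered by increasing degree), the values can be compared against `numpy.polynomial.chebyshev.chebvander`:

```python
import numpy as np
from numpy.polynomial import chebyshev as cheb
from orthnet import Chebyshev

x = np.random.uniform(-1, 1, (10, 1))
T = Chebyshev(x, 4)

# chebvander builds the Vandermonde-style matrix [T_0(x), ..., T_4(x)]
expected = cheb.chebvander(x[:, 0], 4)
print(np.allclose(T.tensor, expected))
```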