Skip to content

Commit

Permalink
Fix test suite failure when no tf wheel (#291)
Browse files Browse the repository at this point in the history
Resolves test module `test_dbmodel.py` using symbols from a TF import
statement that may not be present at test time (e.g. if a user installs
w/o ml backends and then runs make test-full).

[ committed by @MattToast ]
[ reviewed by @mellis13 ]
  • Loading branch information
MattToast authored May 26, 2023
1 parent 02d111f commit 8a5e940
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 44 deletions.
7 changes: 5 additions & 2 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,19 @@ A full list of changes and detailed notes can be found below:

Detailed notes

- Update full test suite to no longer require a tensorflow wheel to be available
at test time. (PR291_)
- Deprecated launcher-specific orchestrators, constants, and ML utilities
were removed. (PR289_)
- Relax the coloredlogs version to be greater than 10.0 (PR288_)
- Update the Github Actions runner image from `macos-10.15`` to `macos-12``. The
former began deprecation in May 2022 and was finally removed in May 2023 (PR285_)
former began deprecation in May 2022 and was finally removed in May 2023. (PR285_)
- The Fortran tutorials had not been fully updated to show how to handle return/error
codes. These have now all been updated (PR284_)
codes. These have now all been updated. (PR284_)
- Orchestrator and Colocated DB now accept a list of interfaces to bind to. The
argument name is still `interface` for backward compatibility reasons. (PR281_)

.. _PR291: https://github.com/CrayLabs/SmartSim/pull/291
.. _PR289: https://github.com/CrayLabs/SmartSim/pull/289
.. _PR288: https://github.com/CrayLabs/SmartSim/pull/288
.. _PR285: https://github.com/CrayLabs/SmartSim/pull/285
Expand Down
79 changes: 37 additions & 42 deletions tests/backends/test_dbmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@
from tensorflow.keras.layers import Conv2D, Input
except ImportError:
should_run_tf = False
else:

class Net(keras.Model):
def __init__(self):
super(Net, self).__init__(name="cnn")
self.conv = Conv2D(1, 3, 1)

def call(self, x):
y = self.conv(x)
return y


should_run_tf &= "tensorflow" in installed_redisai_backends()

Expand All @@ -54,18 +65,35 @@
import torch.nn.functional as F
except ImportError:
should_run_pt = False
else:
# Simple MNIST in PyTorch
class PyTorchNet(nn.Module):
def __init__(self):
super(PyTorchNet, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout(0.25)
self.dropout2 = nn.Dropout(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)

should_run_pt &= "torch" in installed_redisai_backends()

def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output

class Net(keras.Model):
def __init__(self):
super(Net, self).__init__(name="cnn")
self.conv = Conv2D(1, 3, 1)

def call(self, x):
y = self.conv(x)
return y
should_run_pt &= "torch" in installed_redisai_backends()


def save_tf_cnn(path, file_name):
Expand Down Expand Up @@ -95,39 +123,6 @@ def create_tf_cnn():
return serialize_model(model)


# Simple MNIST in PyTorch
try:

class PyTorchNet(nn.Module):
def __init__(self):
super(PyTorchNet, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout(0.25)
self.dropout2 = nn.Dropout(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)

def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output


except Exception:
should_run_pt = False


def save_torch_cnn(path, file_name):
n = PyTorchNet()
example_forward_input = torch.rand(1, 1, 28, 28)
Expand Down

0 comments on commit 8a5e940

Please sign in to comment.