Skip to content

Commit

Permalink
fix list_devices() function is not available issue (#8567)
Browse files Browse the repository at this point in the history
* fix list_devices() function is not available issue

When using tensorflow as backend and the version of tensorflow is under r1.3, the list_devices() function is not available of Session instance. This may cause some issue like keras.utils.multi_gpu_model(model, gpus) function can not work under r1.3 of tensorflow. This change is a hack to fix that issue.

* Update tensorflow_backend.py

fix pep8 check failure

* Update tensorflow_backend.py

follow fchollet review comment. use `if not`.
  • Loading branch information
luoch authored and fchollet committed Nov 24, 2017
1 parent bc28546 commit 1702d1f
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions keras/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import ctc_ops as ctc
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.client import device_lib

from collections import defaultdict

Expand Down Expand Up @@ -186,6 +187,10 @@ def get_session():
v._keras_initialized = True
if uninitialized_vars:
session.run(tf.variables_initializer(uninitialized_vars))
# hack for list_devices() function.
# list_devices() function is not available under tensorflow r1.3.
if not hasattr(session, 'list_devices'):
session.list_devices = lambda: device_lib.list_local_devices()

This comment has been minimized.

Copy link
@datumbox

datumbox Nov 24, 2017

Contributor

@luoch @fchollet
The session.list_devices and the device_lib.list_local_devices() return a different output. Example:

>>> from tensorflow.python.client import device_lib
>>> l1=[d.name for d in device_lib.list_local_devices()]
>>> l1
[u'/cpu:0']
>>> from keras import backend as K
>>> l2=[d.name for d in K.get_session().list_devices()]
>>> l2
['/job:localhost/replica:0/task:0/device:CPU:0']

To make this hack more effective we need to reshape the output. Also note that list_local_devices() has some unexpected side-effects documented here: #8377

This comment has been minimized.

Copy link
@luoch

luoch Nov 27, 2017

Author Contributor

@datumbox Thank you for reminding. It is true that the ouput of device_lib.list_local_devices() need to be reshaped becausse of the output without "device:" keyword. I will submit another PR to fix it.

About the side-effects, actually, list_local_devices() will require to allocate all the rest memory after the get_session() allocated and used, and maybe there is no way to release it. So maybe this hack is not the best way to solve the issue.

But, on the Google Cloud Machine Learning Engine, the tensorflow version is r1.2. I did the hack on my own code, and it is good to me, so that I think I can submit a PR to help.

return session


Expand Down

0 comments on commit 1702d1f

Please sign in to comment.