I've read some answers to this question here and here; however, I'm still a bit puzzled about whether a tf.Variable is or is not a tf.Tensor.
The linked answers deal with the mutability of tf.Variable and mention that tf.Variables maintain their state (when instantiated with the default parameter trainable=True).
What still confuses me is a test case I came across when writing simple unit tests using tf.test.TestCase.
Consider the following code snippet. We have a simple class called Foo which has only one property, a tf.Variable initialized to w:
import tensorflow as tf
import numpy as np

class Foo:
    def __init__(self, w):
        self.w = tf.Variable(w)
Now, let's say you want to test that the instance of Foo has w initialized with a tensor of the same shape as the value passed in via w. The simplest test case could be written as follows:
import tensorflow as tf
import numpy as np
from foo import Foo

class TestFoo(tf.test.TestCase):
    def test_init(self):
        w = np.random.rand(3,2)
        foo = Foo(w)
        init = tf.global_variables_initializer()
        with self.test_session() as sess:
            sess.run(init)
            self.assertShapeEqual(w, foo.w)

if __name__ == '__main__':
    tf.test.main()
Now when you run the test you'll get the following error:
======================================================================
ERROR: test_init (__main__.TestFoo)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test_foo.py", line 12, in test_init
    self.assertShapeEqual(w, foo.w)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/test_util.py", line 1100, in assertShapeEqual
    raise TypeError("tf_tensor must be a Tensor")
TypeError: tf_tensor must be a Tensor
----------------------------------------------------------------------
Ran 2 tests in 0.027s
FAILED (errors=1)
You can "get around" this unit test error by doing something like this (note that assertShapeEqual is replaced with assertEqual):
self.assertEqual(list(w.shape), foo.w.get_shape().as_list())
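Another possible workaround, sketched here rather than confirmed against the test above, is to convert the variable before handing it to assertShapeEqual. This assumes that tf.convert_to_tensor accepts a tf.Variable and returns a tf.Tensor; the names var_as_tensor, var_is_tensor, converted_is_tensor and shapes_match are only illustrative:

```python
import numpy as np
import tensorflow as tf

w = np.random.rand(3, 2)
var = tf.Variable(w)

# Assumption: tf.convert_to_tensor accepts a tf.Variable and hands back a
# tf.Tensor; only the static shape is inspected, so no session is run here.
var_as_tensor = tf.convert_to_tensor(var)

# Type checks that mirror the error message raised by assertShapeEqual.
var_is_tensor = isinstance(var, tf.Variable) and isinstance(var, tf.Tensor)
converted_is_tensor = isinstance(var_as_tensor, tf.Tensor)

# Static shapes, compared the same way as in the assertEqual workaround.
shapes_match = list(w.shape) == var_as_tensor.get_shape().as_list()

print(var_is_tensor, converted_is_tensor, shapes_match)
```

If the assumption holds, the failing line in the test could presumably be written as self.assertShapeEqual(w, tf.convert_to_tensor(foo.w)), though that still does not explain the relationship between the two types.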
What I'm interested in, though, is the tf.Variable vs tf.Tensor relationship.
What the test error seems to suggest is that foo.w is NOT a tf.Tensor, meaning you probably can't use the tf.Tensor API on it. Consider, however, the following interactive Python session:
$ python3
Python 3.6.3 (default, Oct 4 2017, 06:09:15)
[GCC 4.2.1 Compatible Apple LLVM 9.0.0 (clang-900.0.37)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import tensorflow as tf
>>> import numpy as np
>>> w = np.random.rand(3,2)
>>> var = tf.Variable(w)
>>> var.get_shape().as_list()
[3, 2]
>>> list(w.shape)
[3, 2]
>>>
In the session above, we create a variable and call the get_shape() method on it to retrieve its shape. Now, get_shape() is a tf.Tensor API method, as you can see here.
So, to get back to my question: which parts of the tf.Tensor API does tf.Variable implement? If the answer is ALL of them, why does the above test case fail?
self.assertShapeEqual(w, foo.w)
with
raise TypeError("tf_tensor must be a Tensor")
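To pin down what I'm asking, here is a library-free sketch of the situation as I currently understand it. HypotheticalTensor, HypotheticalVariable and assert_shape_equal are stand-ins I made up for illustration, not TensorFlow code, and I'm assuming that assertShapeEqual does a strict isinstance check before comparing shapes, which is what the error message suggests:

```python
class HypotheticalTensor:
    """Made-up stand-in for tf.Tensor (illustrative only)."""

    def __init__(self, shape):
        self._shape = list(shape)

    def get_shape(self):
        return self._shape


class HypotheticalVariable:
    """Made-up stand-in for tf.Variable.

    It offers the same get_shape() method but is deliberately NOT a
    subclass of HypotheticalTensor.
    """

    def __init__(self, shape):
        self._shape = list(shape)

    def get_shape(self):
        return self._shape


def assert_shape_equal(expected_shape, tensor):
    """Made-up stand-in for assertShapeEqual.

    Assumption: the type is checked first, and the shape only afterwards.
    """
    if not isinstance(tensor, HypotheticalTensor):
        raise TypeError("tf_tensor must be a Tensor")
    assert list(expected_shape) == tensor.get_shape()


var = HypotheticalVariable((3, 2))

# Calling the shared method works, just like var.get_shape() in the session.
shape_via_method = var.get_shape()

# The type check rejects the object before any shape is compared.
try:
    assert_shape_equal((3, 2), var)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(shape_via_method, error_message)
```

If this is roughly what is going on, then my question becomes whether tf.Variable is meant to mirror tf.Tensor methods without actually being a tf.Tensor.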
I'm pretty sure I'm missing something fundamental here, or maybe it's a bug in assertShapeEqual? I would appreciate it if someone could shed some light on this.
I'm using the following version of TensorFlow on macOS with Python 3:
tensorflow (1.4.1)