The code is available on GitHub.
Who has never experienced cryptic error messages when developing some functionality in Tensorflow (or Theano, PyTorch, etc.)? I’m sure I’m not the only one who has spent countless hours trying to understand these messages to debug my code. Of course, one gets better with time and learns to interpret them correctly in 90% of the use cases. But what about the other 10%? More importantly, even if you don’t get any error and your Tensorflow code runs smoothly,
How can you be sure that your function performs the intended operation?
There are indeed cases where some minor mistake (or a big misunderstanding) in a call to some Tensorflow primitive breaks everything. Mechanically it runs, but it is logically wrong and does not perform what you need. These non-breaking bugs can be extremely hard to nail down, and they are more common than one might think! This is one of the main motivations behind the recently released OpenAI Baselines. Here is what Szymon Sidor and John Schulman point out:
> Be wary of non-breaking bugs: when we looked through a sample of ten popular reinforcement learning algorithm reimplementations we noticed that six had subtle bugs found by a community member and confirmed by the author. These ranged from mild bugs that ignored gradients on some examples or implemented causal convolutions incorrectly to serious ones that reported scores higher than the true result.
In other words, if we listen to experienced researchers and programmers, the motivation for writing tests is deeply rooted in the quest for
- subtle bugs and correctness (for higher quality code and better research)
- time saving (as probably 150% of my programming teachers keep saying).
First, let’s briefly discuss our options regarding testing:
- run the whole pipeline on some input and examine the output
- adopt a new coding strategy: break functionality into independent functions and test each chunk independently.
Option 1 might be the most common approach. As long as the code runs, overfits a small dataset and achieves good performance on a reference dataset, one might assume that it is correct. However, it does not catch subtle bugs, which is why we’ll also cover option 2.
In option 2, we’ll follow a few guidelines. First, we want to test our functions against the most general use case possible. A good way of doing this is to write a randomized test that samples different inputs and checks the outputs. The second principle is based on another simple observation: most of the time, the code that you write can be reimplemented in Numpy (or even Tensorflow) using a highly non-optimized, naive logic (with for loops, for instance).
In other words, a good procedure is to follow these steps:
```
for _ in range(number_of_random_tests):
    x = sample_input()
    assert tensorflow_function(x) == naive_function(x)
```
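One practical note: for floating-point outputs, the `==` above should really be a tolerance-based comparison, which is what the actual test later in this post uses. Here is a minimal, generic sketch of such a harness (the names `run_randomized_test`, `tf_fn`, `naive_fn` and `sample_input` are hypothetical, not from the original code, and both functions are assumed to return plain numpy arrays):

```python
import numpy as np

def run_randomized_test(tf_fn, naive_fn, sample_input, n_tests=10):
    """Compare a function under test against a naive reference on random inputs."""
    for _ in range(n_tests):
        x = sample_input()
        # np.allclose tolerates small floating-point differences,
        # unlike a strict == comparison
        assert np.allclose(tf_fn(x), naive_fn(x))
```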
As we never truly listen to people as long as our own stuff works, and because we overestimate the difficulty of changing our habits (some might call it laziness),
let’s go over a quick example to help get you started!
Now, let’s take a concrete example in which we want to extract entries of a Tensorflow tensor (advanced indexing is far from straightforward in Tensorflow). We want to write a function `get_entry` in Tensorflow that extracts entries of a batched tensor. In other words, given

- a tensor `t` of shape `[batch, d1, d2]`
- a vector `indices_d1` of shape `[batch]`
- a vector `indices_d2` of shape `[batch]`

we want to compute a tensor `o` of shape `[batch]` such that

`o[i] = t[i, indices_d1[i], indices_d2[i]]`

Notice that without the batch dimension, this is just getting the entry $(i, j)$ of a matrix $A$.
The numpy code for this logic is straightforward, as we don’t care about performance:
```python
import numpy as np

def get_entry_np(t, indices_d1, indices_d2, batch_size):
    result = np.zeros(batch_size)
    for i in range(batch_size):
        result[i] = t[i, indices_d1[i], indices_d2[i]]
    return result
```
The same logic can be achieved in Tensorflow with the `gather_nd` function:
```python
import tensorflow as tf

def get_entry_tf(t, indices_d1, indices_d2, batch_size):
    indices = tf.stack([tf.range(batch_size), indices_d1, indices_d2], axis=1)
    return tf.gather_nd(t, indices)
```
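To see what `gather_nd` is fed here, it can help to look at the `indices` tensor that `tf.stack` builds. Below is a small numpy sketch of the same construction (the index values are made up for illustration): each row is a full coordinate `[i, indices_d1[i], indices_d2[i]]` into `t`, which is exactly the per-row lookup that `gather_nd` performs.

```python
import numpy as np

# made-up example with batch_size = 3, so the range column is [0, 1, 2]
indices_d1 = np.array([4, 0, 2])
indices_d2 = np.array([1, 3, 5])

# numpy equivalent of tf.stack([tf.range(3), indices_d1, indices_d2], axis=1)
indices = np.stack([np.arange(3), indices_d1, indices_d2], axis=1)
print(indices)
# [[0 4 1]
#  [1 0 3]
#  [2 2 5]]
```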
To write a test, we’ll use the standard unittest library. First, import the library with `import unittest` and define a class that will contain our different tests:
```python
class Test(unittest.TestCase):
    def test_1(self):
        # always passes
        self.assertEqual(True, True)

    def test_2(self):
        pass

    # etc.

if __name__ == '__main__':
    unittest.main()
```
In our case, we will write just one test, following the randomized logic. We first sample the size of the input, then sample a tensor `t` and the indices, and finally check that the outputs of the numpy and Tensorflow versions are equal:
```python
def test_get_entry(self):
    success = True
    for _ in range(10):
        # sample input
        batch_size, d1, d2 = map(int, np.random.randint(low=2, high=100, size=3))
        test_input = np.random.random([batch_size, d1, d2])
        # randint's high bound is exclusive, so use high=d1 (not d1-1)
        # to also cover the last valid index
        test_indices_d1 = np.random.randint(low=0, high=d1, size=[batch_size])
        test_indices_d2 = np.random.randint(low=0, high=d2, size=[batch_size])

        # evaluate the numpy version
        test_result = get_entry_np(test_input, test_indices_d1,
                                   test_indices_d2, batch_size)

        # evaluate the tensorflow version
        with tf.Session() as sess:
            tf_input = tf.constant(test_input, dtype=tf.float32)
            tf_indices_d1 = tf.constant(test_indices_d1, dtype=tf.int32)
            tf_indices_d2 = tf.constant(test_indices_d2, dtype=tf.int32)
            tf_result = get_entry_tf(tf_input, tf_indices_d1,
                                     tf_indices_d2, batch_size)
            tf_result = sess.run(tf_result)

        # check that outputs are similar
        success = success and np.allclose(test_result, tf_result)

    self.assertEqual(success, True)
```
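A caveat with randomized tests: a failing run is only useful if you can reproduce it. One minimal way to get that (my suggestion, not part of the original test) is to seed numpy in unittest’s standard `setUp` hook:

```python
import unittest
import numpy as np

class Test(unittest.TestCase):
    def setUp(self):
        # unittest calls this before every test method; the seed value is arbitrary
        np.random.seed(0)
```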
Now, let’s execute the .py file containing our test class.
… and success! We are glad to see that our non-intuitive Tensorflow implementation of this simple trick matches the numpy version. Now we are (almost) sure that if something breaks down in the code, it won’t be this function’s fault!
I was surprised by how quick it was to implement a test (not more than 10 minutes), a time negligible compared to what I spent coming up with the right implementation for our `get_entry` function! Thanks to the tests, you can iterate much faster, try a few solutions and possibilities for the arguments without overthinking it, and get to the desired solution much sooner! To go further into test-driven development, a common piece of advice is to write tests before leaving the workplace for the functionalities that you want to implement the next day! That way, if you’re sure about the functionality and can write a naive implementation, you’ll be able to bootstrap in the morning and quickly focus on what really matters! There is also one benefit of TDD that we haven’t really mentioned so far: it encourages you to break your code into small functions that can be unit-tested. Your code will be more general, shorter and more reusable!