Save and Restore a tf.estimator for inference
Serialize your tf.estimator as a tf.saved_model for a 100x speedup
Code available on GitHub. A more advanced use of the techniques covered in this article can also be found here.
Topics tf.estimator, serving_input_receiver_fn, ServingInputReceiver, export_saved_model, contrib.predictor, from_saved_model
Outline
- Introduction
- Model
- Training the Estimator
- Reload and Predict (first attempt)
- The problem
- A clever fix?
- Exporting the estimator as a tf.saved_model
- Reload and Predict (the good way)
- Conclusion and next steps
Introduction
The tf.estimator
framework is really handy to train and evaluate a model on a given dataset. In this post, I show how a simple tensorflow script can get a state-of-the-art model up and running.
However, when it comes to using your trained Estimator
to get predictions on the fly, things get a little bit messier.
This blog post demonstrates how to properly serialize a tf.estimator, reload it, and predict on new data, by going over a dummy example (fully reproducible by cloning the github repo), getting a 100x speedup over the vanilla implementation.
The good news is that, in the end, it is dead simple and only takes a few lines of code.
Model
Let's say that we have trained an estimator that computes
\[f([x, x]) = 2x\]
We model this as a simple dense layer with one output. In other words, our model has 2 parameters a and b to learn such that
\[ax + bx = 2x\]
Using the tf.estimator
paradigm, here is our model_fn
def model_fn(features, labels, mode, params):
if isinstance(features, dict): # For serving
features = features['feature']
predictions = tf.layers.dense(features, 1)
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode, predictions=predictions)
else:
loss = tf.nn.l2_loss(predictions - labels)
if mode == tf.estimator.ModeKeys.EVAL:
return tf.estimator.EstimatorSpec(
mode, loss=loss)
elif mode == tf.estimator.ModeKeys.TRAIN:
train_op = tf.train.AdamOptimizer(learning_rate=0.5).minimize(
loss, global_step=tf.train.get_global_step())
return tf.estimator.EstimatorSpec(
mode, loss=loss, train_op=train_op)
else:
raise NotImplementedError()
If you think this is a rather preposterous use of Tensorflow and Deep Learning in general, why not have a look at this very serious article on a related topic?
If you need an introduction to tf.estimator
, you can read my introduction with an emphasis on NLP here.
Training the Estimator
To train our model, we generate fake data using tf.data
(here is a short yet comprehensive introduction to tf.data).
def train_generator_fn():
for number in range(100):
yield [number, number], [2 * number]
def train_input_fn():
shapes, types = (2, 1), (tf.float32, tf.float32)
dataset = tf.data.Dataset.from_generator(
train_generator_fn, output_types=types, output_shapes=shapes)
dataset = dataset.batch(20).repeat(200)
return dataset
Once we have our model_fn
and our train_input_fn
, training our tf.estimator
is a matter of 2 lines of code.
estimator = tf.estimator.Estimator(model_fn, 'model', params={})
estimator.train(train_input_fn)
As expected, training takes a few seconds and manages to learn the parameters a
and b
as indicated by a very small loss.
At the end of training, because we specified the model_dir argument of the Estimator, the model directory contains full checkpoints of the graph.
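As a quick sanity check, we can read the learned weights back from those checkpoints (a minimal sketch; the variable names below are the tf.layers.dense defaults, which is an assumption about how the layer was named):
# The kernel entries play the role of a and b: with inputs [x, x] and target 2x,
# they should sum to ~2 and the bias should be ~0.
print(estimator.get_variable_names())
print(estimator.get_variable_value('dense/kernel').sum())  # ~2.0
print(estimator.get_variable_value('dense/bias'))          # ~0.0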
Reload and Predict (first attempt)
Now, let's say that we have a service, external to our model, that keeps sending us new data. Every time we receive a new example, we want to run our model on it. Let's fake the service with a Python generator
def my_service():
for number in range(100, 110):
yield number
Because we don't know our full dataset in advance, our code needs to look like
for number in my_service():
prediction = get_prediction(number)
Imagine a Flask app that calls the get_prediction function every time new data is sent to some URL.
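To make the setting concrete, here is a minimal Flask sketch (purely illustrative: the route and payload format are assumptions, and get_prediction is whatever wrapper we end up writing around the model):
from flask import Flask, jsonify, request

app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def predict():
    # The client posts {"number": 42}; we answer with the model's prediction
    number = float(request.json['number'])
    return jsonify(prediction=float(get_prediction(number)))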
Let's use the predict
method of the tf.estimator.Estimator
class.
We first create a special input_fn
that formats the new number for the Estimator
def example_input_fn(number):
dataset = tf.data.Dataset.from_generator(
lambda: ([number, number] for _ in range(1)),
output_types=tf.float32, output_shapes=(2,))
iterator = dataset.batch(1).make_one_shot_iterator()
next_element = iterator.get_next()
return next_element, None
This is the same
input_fn
as the one used for training except that this time the data generator only yields the number sent by our service.
And we can get predictions by doing
import functools

for nb in my_service():
example_inpf = functools.partial(example_input_fn, nb)
for pred in estimator.predict(example_inpf):
print(pred)
The
predict
method returns a generator. Because our dataset only yields one example, the loop is executed only once and it seems like we achieved our goal: we used the estimator to predict the outcome on new data.
The problem
Now, let's have a look at the logs.
What? Every time we call predict, our estimator instance reloads the weights from disk! As a result, it takes an astonishing 0.19s per loop execution! There must be a better way… Keep in mind that this model only has 2 parameters; what will happen when you have a BFN*?
*starts with Big, ends with Network
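If you want to reproduce the measurement, a rough timing sketch over the 10 requests of our fake service looks like this (exact numbers will vary with your machine):
import time

start = time.time()
for nb in my_service():
    example_inpf = functools.partial(example_input_fn, nb)
    for pred in estimator.predict(example_inpf):
        pass
print('Seconds per prediction: {:.2f}'.format((time.time() - start) / 10))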
A clever fix?
If you have good Python skills, you might notice that everything in this pipeline seems to be built on top of Python generators. We could therefore control the iteration of these chained generators ourselves by calling next() on them.
This clever solution has been investigated by some people, see here for instance.
The overall idea is to build the predict generator on top of the service generator (like Russian dolls) and move to the next prediction whenever we receive a new example. However, apart from being hacky, this method has a downside: chaining 2 generators is not that reliable, because the iteration of the second may depend on a batch of the first. In other words, estimator.predict seems to be built with some kind of batching mechanism which, in the end, causes problems (and I don't want to have to dig into the details of this method, as I shouldn't have to).
Even if there are some workarounds and you could probably make this work eventually, it requires custom hacks depending on your data pipeline.
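For completeness, here is a rough sketch of the idea (not an approach I recommend: depending on how estimator.predict batches and prefetches, this kind of chaining can hang or misalign, which is exactly the fragility discussed above):
import queue

input_queue = queue.Queue()

def queued_input_fn():
    def gen():
        while True:
            # Blocks until the service pushes a new number
            number = input_queue.get()
            yield [number, number]
    dataset = tf.data.Dataset.from_generator(
        gen, output_types=tf.float32, output_shapes=(2,))
    return dataset.batch(1)

predictions = estimator.predict(queued_input_fn)  # one long-lived generator
for nb in my_service():
    input_queue.put(nb)
    print(next(predictions))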
Now, let's explore a much better option, which also seems to be the official one, even though the guides and documentation are pretty scarce and vague on the subject (sadly, this won't come as a surprise, and it is also the existential motivation of this very blog post).
Exporting the estimator as a tf.saved_model
See the official guide.
Tensorflow provides a more efficient way of serializing any inference graph that plays nicely with the rest of the ecosystem, like Tensorflow Serving.
In line with the tf.estimator goal of being an easy-to-use, high-level API, exporting an Estimator as a saved_model is really simple.
We first need to define a special input_fn
(as always, we can't expect the estimator to guess how to format the data).
def serving_input_receiver_fn():
"""Serving input_fn that builds features from placeholders
Returns
-------
tf.estimator.export.ServingInputReceiver
"""
number = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='number')
receiver_tensors = {'number': number}
features = tf.tile(number, multiples=[1, 2])
return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
The user will provide number to the model, which will be fed to the placeholder and transformed to features. Ultimately, the model will receive a dictionary {'feature': features}. NB: the data that the user provides in this case is already batched.
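To see what the receiver actually does to the input, here is a tiny standalone sketch of the tiling step on a concrete batch (evaluated in a throwaway session just to show the shapes):
number = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='number')
features = tf.tile(number, multiples=[1, 2])
with tf.Session() as sess:
    print(sess.run(features, feed_dict={number: [[100.], [101.]]}))
    # [[100. 100.]
    #  [101. 101.]]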
Then, reloading and serializing the estimator is straightforward
estimator = tf.estimator.Estimator(model_fn, 'model', params={})
estimator.export_saved_model('saved_model', serving_input_receiver_fn)
For each new export, you will find a new time-stamped subdirectory (here 1543003465), containing the graph definition as a protocol buffer (saved_model.pb) along with the weights (in variables/). If your graph uses other resources (like vocab files for lookup tables, that sort of thing), you will also find them under an assets/ directory.
Reload and Predict (the good way)
While the saved model can be used directly by tools like Tensorflow Serving, some people (including me) might want to reload it in Python. After all, maybe your application is built with Flask.
First, let's find the latest set of weights by exploring the subdirectories under saved_model
from pathlib import Path

export_dir = 'saved_model'
subdirs = [x for x in Path(export_dir).iterdir()
if x.is_dir() and 'temp' not in str(x)]
latest = str(sorted(subdirs)[-1])
If you don't know / don't use the pathlib module (Python 3 only), try using it. It bundles a lot of os.path functionality (and more) in a much nicer, easier-to-use package. I started using it after reading about it on this blog, which also has a lot of other excellent articles.
Once we have found the directory containing the latest set of weights, we can use a predictor to reload them. I first heard about this very simple (yet incredibly powerful) class in this stackoverflow answer.
from tensorflow.contrib import predictor
predict_fn = predictor.from_saved_model(latest)
for nb in my_service():
pred = predict_fn({'number': [[nb]]})['output']
The
predictor
class also comes with afrom_estimator
method!
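For instance, here is a minimal sketch of the from_estimator variant, which skips the export-to-disk step entirely (reusing the names defined earlier, and assuming the default 'output' key for a single prediction tensor):
from tensorflow.contrib import predictor

estimator = tf.estimator.Estimator(model_fn, 'model', params={})
predict_fn = predictor.from_estimator(estimator, serving_input_receiver_fn)
for nb in my_service():
    pred = predict_fn({'number': [[nb]]})['output']
    print(pred)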
Under the hood, this predictor class uses all the tools mentioned in the official guide. It uses tf.saved_model.loader to load the tf.saved_model into a session, reloads the serving signature from the protocol buffer, extracts the input and output tensors of the graph, and bundles everything in a nice callable for ease of use.
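Roughly, a manual version might look like the sketch below (assuming the default 'serving_default' signature and the receiver key from our export; you should not need this in practice, since predictor already handles it):
graph = tf.Graph()
with graph.as_default():
    sess = tf.Session()
    # Load the saved_model into the session and grab its serving signature
    meta_graph_def = tf.saved_model.loader.load(sess, ['serve'], latest)
    signature = meta_graph_def.signature_def['serving_default']
    input_name = signature.inputs['number'].name
    output_name = list(signature.outputs.values())[0].name
    for nb in my_service():
        print(sess.run(output_name, feed_dict={input_name: [[nb]]}))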
We could probably have gleaned enough information here and there in the official documentation and hacked it together ourselves, but a better implementation is already there. Why not advertise it more in the official guide? Anyway, a big thank you to the developers who made it available!
Conclusion and next steps
It is still a mystery to me why the tf.estimator API does not offer an efficient predict method for on-the-fly requests, or at least advertise the tools covered in this blog post a bit more. After all, with a few lines of code you get an even better result! The relative lack of documentation and official guides is nothing new: it is probably really hard to keep up with the rapid evolution of the framework while still offering comprehensive and coherent documentation to its users.
Using the predictor
as explained above yields a 100x speedup on our dummy example!
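To reproduce that comparison, here is the same kind of timing loop as before, now going through the predictor (exact numbers will of course depend on your machine):
import time

start = time.time()
for nb in my_service():
    predict_fn({'number': [[nb]]})
print('Seconds per prediction: {:.4f}'.format((time.time() - start) / 10))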
If you're curious to see a more "real-life" use of these methods, my other repo using tf.estimator for NER (Named Entity Recognition using a bi-LSTM + CRF) implements them all!