Neural nets in TensorFlow.
TensorFlow is a relatively low-level framework for building and executing computational graphs. There are higher-level frameworks built on top of TensorFlow that implement neural networks (e.g. Keras).
This repository implements one such high-level framework. In research I need low-level control over the computational graph, but I also often reuse basic neural net architectures. That makes pre-existing frameworks inconvenient (they hide the low-level details), yet I don't want to reimplement basic neural net components every time.
Why would you use this repository?
- You are a TensorFlow beginner and want to see how to implement things. I certainly learned a lot from reading other people's TensorFlow code!
- You want to use my implementations in your own projects. Please do, though you'll probably learn more and get better results by writing your own.
I tried to mimic the scikit-learn interface: you fit a network with nn.fit and predict with nn.predict. Some modules offer additional methods, e.g. GANs provide gan.sample(). See the individual module documentation for details.
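To illustrate the fit/predict pattern in isolation, here is a toy stand-in (the class, its solver, and all names are invented for illustration; this is not code from the repo):

```python
import numpy as np

class TinyRegressor:
    """Toy model illustrating the scikit-learn-style fit/predict
    interface used throughout this repo (not actual repo code)."""

    def fit(self, X, Y, epochs=10, lr=0.1):
        # Least-squares via gradient descent, standing in for a TF training loop.
        n, d = X.shape
        self.w = np.zeros(d)
        for _ in range(epochs):
            grad = X.T @ (X @ self.w - Y) / n
            self.w -= lr * grad
        return self

    def predict(self, X):
        return X @ self.w

X = np.array([[1.0], [2.0], [3.0]])
Y = np.array([2.0, 4.0, 6.0])
model = TinyRegressor().fit(X, Y, epochs=500)
print(np.round(model.predict(X), 2))  # ≈ [2. 4. 6.]
```

The repo's networks follow the same shape, with a TensorFlow session threaded through the calls.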
Each module's usage is demonstrated in its __main__ section. For example, fcnn.py contains a section that uses a Fully Convolutional Neural Network (FCNN) with batch normalization and residual connections to denoise MNIST images:
```python
[...] # Code that defines the FCNN.
[...] # Load MNIST data and add random uniform(0, 1) noise at 20% of the pixels.

# Define the graph.
fcnn = FCNN(x_shape=ims_tr.shape[1:])

# Create a TensorFlow session and train the net.
with tf.Session() as sess:
    # Initialize the variables.
    sess.run(tf.global_variables_initializer())
    # Use a writer object for TensorBoard visualization.
    summary = tf.summary.merge_all()
    writer = tf.summary.FileWriter('logs/fcnn')
    writer.add_graph(sess.graph)
    # Fit the net.
    fcnn.fit(X_tr, Y_tr, sess, epochs=100, writer=writer, summary=summary)
    # Predict.
    Y_pred = fcnn.predict(X_ts, sess).reshape(-1, 28, 28)

[...] # More code that plots the results.
```
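The corruption step elided above could be sketched as follows. This is a guess at the recipe the comment describes (one plausible reading: overwrite a random ~20% of pixels with uniform(0, 1) noise); the function name and details are mine, not the repo's:

```python
import numpy as np

def corrupt(images, frac=0.2, seed=0):
    """Overwrite a random `frac` of the pixels with uniform(0, 1) noise.
    A sketch of the corruption described above, not the repo's exact code."""
    rng = np.random.RandomState(seed)
    noisy = images.copy()
    mask = rng.rand(*images.shape) < frac   # ~frac of pixels selected at random
    noisy[mask] = rng.rand(mask.sum())      # fill them with uniform noise
    return noisy

X = np.zeros((10, 28, 28))
X_noisy = corrupt(X)
print(X_noisy.shape)  # (10, 28, 28)
```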
This trains an FCNN. You can use TensorBoard to visualize the graph. For example, the image below illustrates the graph and zooms in on one specific batch-norm/residual layer:
Our FCNN learned to denoise noisy MNIST very well:
At the moment, the repository contains the following models:
- nn.py: Multi-layer perceptron (MLP) with Dropout (arXiv:1207.0580).
- nn.py: Residual Network (arXiv:1512.03385).
- nn.py: Highway Network (arXiv:1505.00387).
- gan.py: Least-Squares Generative Adversarial Network (arXiv:1611.04076v2; in my experience the best GAN, though unlike Wasserstein GANs it doesn't have a convergence criterion).
- cgan.py: Conditional Least-Squares Generative Adversarial Network (arXiv:1411.1784).
- mtn.py: Multi-Task Networks (my own creation) -- learn from multiple datasets with related inputs but different output tasks.
- fcnn.py: Fully-convolutional neural nets.
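To ground the residual-connection terminology used above (arXiv:1512.03385): a residual block adds the layer's input back to its output, y = x + F(x), so the layer only has to learn a correction to the identity. A minimal NumPy sketch of one block's forward pass (an illustration of the idea, not the repo's implementation):

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def residual_block(x, W1, W2):
    """y = x + F(x): the block learns a residual on top of the identity.
    Widths are kept equal so the skip connection is a plain elementwise add."""
    h = relu(x @ W1)   # F's hidden layer
    return x + h @ W2  # add the input back (the skip connection)

rng = np.random.RandomState(0)
x = rng.randn(4, 8)
W1, W2 = rng.randn(8, 8) * 0.1, rng.randn(8, 8) * 0.1
y = residual_block(x, W1, W2)
print(y.shape)  # (4, 8)
```

With all-zero weights the block reduces to the identity, which is what makes very deep stacks of such blocks easy to optimize.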
Everything should work with both Python 2 and 3. Dependencies:
- NumPy >= 1.12
- TensorFlow >= 1.0.0