Add in softmax backprop
VikParuchuri committed Feb 4, 2023
1 parent 8c2bfab commit 9a79c8a
Showing 3 changed files with 370 additions and 36 deletions.
95 changes: 95 additions & 0 deletions explanations/comp_graph.ipynb
@@ -0,0 +1,95 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Backpropagation in depth\n",
"\n",
"In the [last lesson](https://github.com/VikParuchuri/zero_to_gpt/blob/master/explanations/rnn.ipynb), we learned how to create a recurrent neural network. We now know how to build several network architectures using components like dense layers, softmax, and recurrent layers.\n",
"\n",
"We've been a bit loose with how we cover backpropagation, to make neural network architecture easier to understand. In this lesson, we'll do a deep dive into how backpropagation works. We'll do this by building a computational graph that keeps track of the different operations that transform the input data."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"To start, let's read in some data and define a 2-layer neural network that can make predictions:"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Read in our data, and fill missing values\n",
"data = pd.read_csv(\"../../data/clean_weather.csv\", index_col=0)\n",
"data = data.ffill()\n",
"\n",
"# Create data sets of our predictors and targets (x and y)\n",
"x = data[:10][[\"tmax\", \"tmin\", \"rain\"]].to_numpy()\n",
"y = data[:10][[\"tmax_tomorrow\"]].to_numpy()"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"Once we have the data, we'll initialize our parameters for 2 layers. To keep things simple, we'll omit the bias, so we just need weights for each layer:"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"import numpy as np\n",
"w1 = np.random.rand(3, 3)\n",
"w2 = np.random.rand(3,1)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
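
Note: as a rough sketch of where this notebook is headed (not part of the commit), the forward pass for the two-layer, bias-free network above comes down to two matrix multiplies. The random seed and the mean squared error loss are assumptions added for illustration; the data path follows the cell above.

import numpy as np
import pandas as pd

# Read the weather data and fill gaps, as in the notebook cell above.
data = pd.read_csv("../data/clean_weather.csv", index_col=0).ffill()
x = data[:10][["tmax", "tmin", "rain"]].to_numpy()
y = data[:10][["tmax_tomorrow"]].to_numpy()

np.random.seed(0)  # added here for reproducibility
w1 = np.random.rand(3, 3)
w2 = np.random.rand(3, 1)

# Forward pass: two matrix multiplies, no biases.
hidden = x @ w1           # (10, 3) @ (3, 3) -> (10, 3)
prediction = hidden @ w2  # (10, 3) @ (3, 1) -> (10, 1)

# Mean squared error against tomorrow's actual max temperature.
loss = np.mean((prediction - y) ** 2)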
36 changes: 29 additions & 7 deletions nnets/graph.py
@@ -1,18 +1,33 @@
import graphviz
from copy import deepcopy
from IPython.display import Latex
import numpy as np

def reshape_grad(orig, grad):
    if grad.shape != orig.shape:
        try:
            summed = np.sum(grad, axis=-1)
            if summed.shape != orig.shape:
                summed = summed.reshape(orig.shape)
        except ValueError:
            summed = np.sum(grad, axis=0)
            if summed.shape != orig.shape:
                summed = summed.reshape(orig.shape)
        return summed
    return grad

class Node():
    def __init__(self, *args, out=None, desc=None):
        """
        Initialize a node.
        :param args: Any arguments needs to compute the forward pass.
        :param args: Any arguments needed to compute the forward pass.
        :param out: A string representing the output of the node.
        :param desc: A string describing the node.
        """
        self.nodes = args
        self.desc = desc
        self.out = out
        self.needs_grad = False
        if self.desc and not self.out:
            self.out = self.desc

@@ -33,7 +48,7 @@ def visualize(self, graph, backward=False):
            node_id = str(id(node))
            node.visualize(graph, backward=backward)
            if backward:
                label = f"dL/(d{node.out})"
                label = f"d({node.out})"
                # ensure that x and y don't get a grad label
                if isinstance(node, Parameter) and not node.needs_grad:
                    label = None
@@ -45,8 +60,9 @@ def zero_grad(self):
        """
        Zero out the gradients on each parameter.
        """
        self.grad = 0
        self.derivative = []
        if self.needs_grad:
            self.grad = None
            self.derivative = []
        if self.nodes is None:
            return
        for node in self.nodes:
@@ -72,7 +88,7 @@ def apply_bwd(self, grad):
            args.append(node.apply_bwd(grad))

    def generate_graph(self, backward=False):
        graph = graphviz.Digraph('fwd_pass', format="png")
        graph = graphviz.Digraph('fwd_pass', format="png", strict=True)
        graph.attr(rankdir='LR')
        self.visualize(graph, backward=backward)
        return graph
@@ -81,7 +97,8 @@ def generate_derivative_chains(self, chain=None):
        if chain is None:
            chain = []
        if self.nodes is None:
            self.derivative.append(chain)
            if self.needs_grad:
                self.derivative.append(chain)
            return
        for node in self.nodes:
            node_chain = deepcopy(chain)
@@ -92,7 +109,7 @@ def display_partial_derivative(self):
        flat_eqs = ["*".join(item) for item in self.derivative]
        lhs = f"\\frac{{\\partial L}}{{\\partial {self.out}}}"
        rhs = "+".join(flat_eqs)
        rhs = " + \\\\".join(flat_eqs)
        return f"{lhs} = {rhs}"

    def forward(self, *args):
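
For intuition (an illustration, not repository code): display_partial_derivative joins each derivative chain with "*" and now separates summed terms with " + \\", so each term lands on its own line when the LaTeX renders. For a two-layer, bias-free network like the one in the notebook, writing the hidden output as h and the prediction as \hat{y} (names assumed for illustration), the partial derivative for w_1 renders roughly as:

\frac{\partial L}{\partial w_1} = \frac{\partial L}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial h} \cdot \frac{\partial h}{\partial w_1}

When several paths connect a parameter to the loss, each path contributes one such product, and the products are summed, one per rendered line.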
@@ -115,4 +132,9 @@ def forward(self):
        return self.data

    def backward(self, grad):
        if not self.needs_grad:
            return
        grad = reshape_grad(self.data, grad)
        if self.grad is None:
            self.grad = np.zeros_like(self.data)
        self.grad += grad
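
A hypothetical example (not from the repository) of the shape mismatch that backward now routes through reshape_grad: a parameter broadcast across a batch during the forward pass receives an upstream gradient that still carries the batch axis, and that axis has to be summed away before accumulating into self.grad.

import numpy as np
from nnets.graph import reshape_grad  # the function defined above

# A 3-element parameter broadcast across a batch of 10 rows.
param = np.zeros(3)
grad = np.ones((10, 3))  # upstream gradient still carries the batch axis

fixed = reshape_grad(param, grad)
# Summing over axis -1 gives shape (10,), which can't reshape to (3,) and
# raises ValueError, so the fallback sums over axis 0 (the batch axis).
print(fixed.shape)  # (3,)
print(fixed)        # [10. 10. 10.] -- one summed gradient per weight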
