Training Neural Networks: The 3-Step Process

  • Step 1: Specify how to compute the output given input x and parameters w, b
  • Similar to logistic regression, where we defined:
    • f(x) = g(w·x + b), where g is the sigmoid function
    • g(z) = 1/(1+e^(-z)), where z = w·x + b (see the NumPy sketch below)
  • For neural networks, we define the architecture with:
Define Neural Network
import tensorflow as tf

# Architecture: two hidden layers (25 and 15 units) and a 1-unit output layer
model = tf.keras.Sequential([
    tf.keras.layers.Dense(25, activation='sigmoid'),
    tf.keras.layers.Dense(15, activation='sigmoid'),
    tf.keras.layers.Dense(1, activation='sigmoid')
])
  • This code specifies:
    • 25 units in the first hidden layer
    • 15 units in the second hidden layer
    • 1 unit in the output layer
    • All using the sigmoid activation function
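
As a quick refresher on the logistic-regression building block mentioned above, here is a minimal NumPy sketch of a single sigmoid unit; the weight, bias, and input values are made up for illustration.

Sigmoid Unit (sketch)
import numpy as np

def sigmoid(z):
    # g(z) = 1/(1+e^(-z))
    return 1 / (1 + np.exp(-z))

# A single unit computes f(x) = g(w·x + b); w, b, and x are illustrative values
w = np.array([1.5, -0.5])
b = 0.1
x = np.array([2.0, 1.0])
print(sigmoid(np.dot(w, x) + b))  # ≈ 0.93, a value between 0 and 1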
  • Step 2: Define the loss function for a single training example

  • Define the cost function as the average loss over the entire training set

  • For binary classification:

    • Loss function: L(f(x), y) = -y log(f(x)) - (1-y)log(1-f(x))
    • In TensorFlow, this is the “binary cross-entropy” loss (checked numerically below)
  • Compile model with:

    Compile Model
    model.compile(loss=tf.keras.losses.BinaryCrossentropy())
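
To connect the formula to the Keras call, the following sketch compares BinaryCrossentropy against the loss computed directly from the formula; the labels and predicted probabilities are made-up values.

Check Binary Cross-Entropy (sketch)
import numpy as np
import tensorflow as tf

# Made-up labels y and predicted probabilities f(x)
y_true = np.array([[1.0], [0.0]])
y_pred = np.array([[0.9], [0.2]])

keras_loss = tf.keras.losses.BinaryCrossentropy()(y_true, y_pred).numpy()
manual_loss = np.mean(-y_true * np.log(y_pred) - (1 - y_true) * np.log(1 - y_pred))
print(keras_loss, manual_loss)  # both ≈ 0.164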
  • For regression problems, use mean squared error:

    • Loss function: (1/2)(f(x) - y)²

    • In TensorFlow:

      Regression Loss
      model.compile(loss=tf.keras.losses.MeanSquaredError())
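
One detail worth flagging: Keras's MeanSquaredError averages (f(x) - y)² without the ½ factor shown above. The ½ only rescales the cost, so the minimizing parameters are unchanged. A quick check with made-up values:

Check Mean Squared Error (sketch)
import numpy as np
import tensorflow as tf

# Made-up targets y and predictions f(x)
y_true = np.array([[3.0], [5.0]])
y_pred = np.array([[2.5], [5.5]])

print(tf.keras.losses.MeanSquaredError()(y_true, y_pred).numpy())  # 0.25
print(0.5 * np.mean((y_pred - y_true) ** 2))                       # 0.125, the course's ½ form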
  • Step 3: Use gradient descent to minimize the cost function J(W,B)

  • Update parameters: W = W - α·∂J/∂W, B = B - α·∂J/∂B

  • TensorFlow computes derivatives using backpropagation

  • Train with:

    Train Model
    model.fit(X, y, epochs=100)
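
To make the update rule concrete, here is a sketch of one hand-rolled gradient-descent step using tf.GradientTape, TensorFlow's automatic-differentiation mechanism; the tiny dataset and learning rate α are illustrative, and model.fit performs this loop for you.

One Gradient Descent Step (sketch)
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation='sigmoid')])
loss_fn = tf.keras.losses.BinaryCrossentropy()

# Made-up training data and learning rate alpha
X = tf.constant([[0.0, 1.0], [1.0, 0.0]])
y = tf.constant([[1.0], [0.0]])
alpha = 0.1

with tf.GradientTape() as tape:
    loss = loss_fn(y, model(X))                          # forward pass
grads = tape.gradient(loss, model.trainable_variables)   # backpropagation
for var, grad in zip(model.trainable_variables, grads):
    var.assign_sub(alpha * grad)                         # W = W - α·∂J/∂W, B = B - α·∂J/∂B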
  • In practice, TensorFlow also offers optimizers that converge faster than plain gradient descent, such as Adam

  • Most commercial implementations use libraries like TensorFlow or PyTorch
  • Similar to how developers now use libraries for:
    • Sorting algorithms
    • Mathematical operations (square roots)
    • Matrix operations
  • Understanding the internal mechanisms is still valuable for debugging

TensorFlow simplifies the neural network training process into three clear steps: defining the architecture, specifying the loss function, and minimizing the cost function. While libraries handle much of the complexity, understanding the underlying mechanics helps troubleshoot unexpected behaviors.