Added Dense and Conv BatchEnsemble layers along with unit tests and example on MNIST classification using LeNet5 #4
@@ -0,0 +1,207 @@
## Classification of MNIST dataset
## with the convolutional neural network known as LeNet5.
## This script also combines various
## packages from the Julia ecosystem with Flux.
using Flux
using Flux.Data: DataLoader
using Flux.Optimise: Optimiser, WeightDecay
using Flux: onehotbatch, onecold, glorot_normal, label_smoothing
using Flux.Losses: logitcrossentropy
using Statistics, Random
using Logging: with_logger
using TensorBoardLogger: TBLogger, tb_overwrite, set_step!, set_step_increment!
using ProgressMeter: @showprogress
import MLDatasets
import BSON
using CUDA
using Formatting

using DeepUncertainty

# LeNet5 "constructor".
# The model can be adapted to any image size
# and any number of output classes.
function LeNet5(args; imgsize = (28, 28, 1), nclasses = 10)
    out_conv_size = (imgsize[1] ÷ 4 - 3, imgsize[2] ÷ 4 - 3, 16)

    return Chain(
        ConvBatchEnsemble((5, 5), imgsize[end] => 6, args.rank, args.ensemble_size, relu),
        MaxPool((2, 2)),
        ConvBatchEnsemble((5, 5), 6 => 16, args.rank, args.ensemble_size, relu),
        MaxPool((2, 2)),
        flatten,
        DenseBatchEnsemble(prod(out_conv_size), 120, args.rank, args.ensemble_size, relu),
        DenseBatchEnsemble(120, 84, args.rank, args.ensemble_size, relu),
        DenseBatchEnsemble(84, nclasses, args.rank, args.ensemble_size),
    )
end
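For the default 28 × 28 MNIST input, the flattened feature size fed into the first DenseBatchEnsemble layer works out as follows (a quick check, not part of the diff):

out_conv_size = (28 ÷ 4 - 3, 28 ÷ 4 - 3, 16)   # = (4, 4, 16)
prod(out_conv_size)                            # = 256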

function get_data(args)
    xtrain, ytrain = MLDatasets.MNIST.traindata(Float32)
    xtest, ytest = MLDatasets.MNIST.testdata(Float32)

    xtrain = reshape(xtrain, 28, 28, 1, :)
    xtest = reshape(xtest, 28, 28, 1, :)

    ytrain, ytest = onehotbatch(ytrain, 0:9), onehotbatch(ytest, 0:9)

    train_loader = DataLoader(
        (xtrain, ytrain),
        batchsize = args.batchsize,
        shuffle = true,
        partial = false,
    )
    test_loader = DataLoader((xtest, ytest), batchsize = args.batchsize, partial = false)

    return train_loader, test_loader
end

loss(ŷ, y) = logitcrossentropy(ŷ, y)

function accuracy(preds, labels)
    acc = sum(onecold(preds |> cpu) .== onecold(labels |> cpu))
    return acc
end

function eval_loss_accuracy(args, loader, model, device)
    l = [0.0f0 for x = 1:args.ensemble_size]
    acc = [0 for x = 1:args.ensemble_size]
    ece_list = [0.0f0 for x = 1:args.ensemble_size]
    ntot = 0
    mean_l = 0
    mean_acc = 0
    mean_ece = 0
    for (x, y) in loader
        x = repeat(x, 1, 1, 1, args.ensemble_size)
        x, y = x |> device, y |> device
        # Perform the forward pass
        ŷ = model(x)
        ŷ = softmax(ŷ, dims = 1)
        # Reshape the predictions into [classes, batch_size, ensemble_size]
        reshaped_ŷ = reshape(ŷ, size(ŷ)[1], args.batchsize, args.ensemble_size)
        # Loop through each model's predictions
        for ensemble = 1:args.ensemble_size
            model_predictions = reshaped_ŷ[:, :, ensemble]
            # Calculate individual loss
            l[ensemble] += loss(model_predictions, y) * size(model_predictions)[end]
            acc[ensemble] += accuracy(model_predictions, y)
            ece_list[ensemble] +=
                ExpectedCalibrationError(model_predictions |> cpu, onecold(y |> cpu)) *
                args.batchsize
        end
        # Get the mean predictions
        mean_predictions = mean(reshaped_ŷ, dims = ndims(reshaped_ŷ))
        mean_predictions = dropdims(mean_predictions, dims = ndims(mean_predictions))
        mean_l += loss(mean_predictions, y) * size(mean_predictions)[end]
        mean_acc += accuracy(mean_predictions, y)
        mean_ece +=
            ExpectedCalibrationError(mean_predictions |> cpu, onecold(y |> cpu)) *
            args.batchsize
        ntot += size(mean_predictions)[end]
    end
    # Normalize the loss
    losses = [loss / ntot |> round4 for loss in l]
    acc = [a / ntot * 100 |> round4 for a in acc]
    ece_list = [x / ntot |> round4 for x in ece_list]
    # Calculate mean loss
    mean_l = mean_l / ntot |> round4
    mean_acc = mean_acc / ntot * 100 |> round4
    mean_ece = mean_ece / ntot |> round4

    # Print the per-ensemble-member loss and accuracy
    for ensemble = 1:args.ensemble_size
        @info (format(
            "Model {} Loss: {} Accuracy: {} ECE: {}",
            ensemble,
            losses[ensemble],
            acc[ensemble],
            ece_list[ensemble],
        ))
    end
    @info (format(
        "Mean Loss: {} Mean Accuracy: {} Mean ECE: {}",
        mean_l,
        mean_acc,
        mean_ece,
    ))
    @info "==========================================================="
    return nothing
end
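A small shape sketch of the reshape-and-average step above (illustrative only, not part of the diff; the sizes are hypothetical):

using Statistics

nclasses, batchsize, ensemble_size = 10, 2, 3
ŷ = rand(Float32, nclasses, batchsize * ensemble_size)      # flattened predictions from the repeated batch
reshaped_ŷ = reshape(ŷ, nclasses, batchsize, ensemble_size) # [classes, batch_size, ensemble_size]
mean_pred = dropdims(mean(reshaped_ŷ, dims = 3), dims = 3)  # average over ensemble members
@assert size(mean_pred) == (nclasses, batchsize)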

## utility functions
num_params(model) = sum(length, Flux.params(model))
round4(x) = round(x, digits = 4)

# arguments for the `train` function
Base.@kwdef mutable struct Args
    η = 3e-4             # learning rate
    λ = 0                # L2 regularizer param, implemented as weight decay
    batchsize = 32       # batch size
    epochs = 10          # number of epochs
    seed = 0             # set seed > 0 for reproducibility
    use_cuda = true      # if true use cuda (if available)
    infotime = 1         # report every `infotime` epochs
    checktime = 5        # Save the model every `checktime` epochs. Set to 0 for no checkpoints.
    savepath = "runs/"   # results path
    rank = 1
    ensemble_size = 4
end

function train(; kws...)
    args = Args(; kws...)
    args.seed > 0 && Random.seed!(args.seed)
    use_cuda = args.use_cuda && CUDA.functional()

    if use_cuda
        device = gpu
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    ## DATA
    train_loader, test_loader = get_data(args)
    @info "Dataset MNIST: $(train_loader.nobs) train and $(test_loader.nobs) test examples"

    ## MODEL AND OPTIMIZER
    model = LeNet5(args) |> device
    @info "LeNet5 model: $(num_params(model)) trainable params"

    ps = Flux.params(model)

    opt = ADAM(args.η)
    if args.λ > 0 # add weight decay, equivalent to L2 regularization
        opt = Optimiser(WeightDecay(args.λ), opt)
    end

    function report(epoch)
        # @info "Train Metrics"
        # eval_loss_accuracy(args, train_loader, model, device)
        @info "Test metrics"
        eval_loss_accuracy(args, test_loader, model, device)
    end

    ## TRAINING
    @info "Start Training"
    report(0)
    for epoch = 1:args.epochs
        @showprogress for (x, y) in train_loader
            # Make copies of batches for ensembles
            x = repeat(x, 1, 1, 1, args.ensemble_size)
            y = repeat(y, 1, args.ensemble_size)
            x, y = x |> device, y |> device
            gs = Flux.gradient(ps) do
                ŷ = model(x)
                loss(ŷ, y)
            end

            Flux.Optimise.update!(opt, ps, gs)
        end

        ## Printing and logging
        epoch % args.infotime == 0 && report(epoch)
    end
end

train()
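Since `train` forwards keyword arguments to `Args`, hyperparameters can be overridden at the call site, for example (illustrative values):

train(epochs = 5, batchsize = 64, ensemble_size = 4, rank = 1)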
@@ -1,10 +1,17 @@
module DeepUncertainty

using Flux
using Random
using Flux: @functor, glorot_normal, create_bias

# Export layers
export MCLayer, MCDense, MCConv
export DenseBatchEnsemble, ConvBatchEnsemble
export mean_loglikelihood, brier_score, ExpectedCalibrationError, prediction_metrics

include("metrics.jl")
include("layers/mclayers.jl")
include("layers/BatchEnsemble/dense.jl")
include("layers/BatchEnsemble/conv.jl")

end
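With these exports in place, both BatchEnsemble layer types are available directly after `using DeepUncertainty`, as the MNIST example above shows; a minimal construction sketch (illustrative sizes):

using DeepUncertainty
dense_be = DenseBatchEnsemble(84, 10, 1, 4)             # rank 1, ensemble of 4, identity activation
conv_be = ConvBatchEnsemble((5, 5), 1 => 6, 1, 4, relu) # rank 1, ensemble of 4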
@@ -0,0 +1,145 @@
"""
    ConvBatchEnsemble(filter, in => out, rank,
                      ensemble_size, σ = identity;
                      stride = 1, pad = 0, dilation = 1,
                      groups = 1, [bias, weight, init])
    ConvBatchEnsemble(layer, alpha, gamma, ensemble_bias, ensemble_act, rank)

Creates a conv BatchEnsemble layer. BatchEnsemble is a memory-efficient alternative
to deep ensembles. In deep ensembles, if the ensemble size is N, N different models
are trained, making the time and memory complexity O(N * complexity of one network).
BatchEnsemble generates a weight matrix for each ensemble member from a pair of
rank-1 vectors R (alpha) and S (gamma): the outer product RS' is multiplied
element-wise with the shared weight matrix W. R and S are also called fast weights.

Reference - https://arxiv.org/abs/2002.06715

During both training and testing, we repeat the samples along the batch dimension
N times, where N is the ensemble_size. For example, if each mini-batch has 10 samples
and our ensemble size is 4, then the actual input to the layer has 40 samples.
The output of the layer has 40 samples as well, and each group of 10 samples can be
considered the output of one ensemble member.

# Fields
- `layer`: The convolutional layer which transforms the perturbed input to output
- `alpha`: The first fast weight of size (in_dim, ensemble_size)
- `gamma`: The second fast weight of size (out_dim, ensemble_size)
- `ensemble_bias`: Bias added to the ensemble output, separate from the conv layer bias
- `ensemble_act`: The activation function applied to the ensemble output
- `rank`: Rank of the fast weights (rank > 1 doesn't work on GPU for now)

# Arguments
- `filter::NTuple{N,Integer}`: Kernel dimensions, e.g. (5, 5)
- `ch::Pair{<:Integer,<:Integer}`: Input channels => output channels
- `rank::Integer`: Rank of the fast weights
- `ensemble_size::Integer`: Number of models in the ensemble
- `σ::F=identity`: Activation of the conv layer, defaults to identity
- `init=glorot_normal`: Initialization function, defaults to glorot_normal
- `alpha_init=glorot_normal`: Initialization function for the alpha fast weight,
    defaults to glorot_normal
- `gamma_init=glorot_normal`: Initialization function for the gamma fast weight,
    defaults to glorot_normal
- `bias::Bool=true`: Toggle the usage of bias in the conv layer
- `ensemble_bias::Bool=true`: Toggle the usage of ensemble bias
- `ensemble_act::F=identity`: Activation function for ensemble outputs
"""
struct ConvBatchEnsemble{L,F,M,B}
    layer::L
    alpha::M
    gamma::M
    ensemble_bias::B
    ensemble_act::F
    rank::Any
    function ConvBatchEnsemble(
        layer::L,
        alpha::M,
        gamma::M,
        ensemble_bias = true,
        ensemble_act::F = identity,
        rank = 1,
    ) where {M,F,L}
        ensemble_bias = create_bias(gamma, ensemble_bias, size(gamma)[1], size(gamma)[2])

        new{typeof(layer),F,M,typeof(ensemble_bias)}(
            layer,
            alpha,
            gamma,
            ensemble_bias,
            ensemble_act,
            rank,
        )
    end
end

Review comment (on the `create_bias` call above): Can you test it with FluxML/Flux.jl#1402

function ConvBatchEnsemble(
    k::NTuple{N,Integer},
    ch::Pair{<:Integer,<:Integer},
    rank::Integer,
    ensemble_size::Integer,
    σ = identity;
    init = glorot_normal,
    alpha_init = glorot_normal,
    gamma_init = glorot_normal,
    stride = 1,
    pad = 0,
    dilation = 1,
    groups = 1,
    bias = true,
    ensemble_bias = true,
    ensemble_act = identity,
) where {N}
    layer = Flux.Conv(
        k,
        ch,
        σ;
        stride = stride,
        pad = pad,
        dilation = dilation,
        init = init,
        groups = groups,
        bias = bias,
    )
    in_dim = ch[1]
    out_dim = ch[2]
    if rank >= 1
        alpha_shape = (in_dim, ensemble_size)
        gamma_shape = (out_dim, ensemble_size)
    else
        error("Rank must be >= 1.")
    end
    alpha = alpha_init(alpha_shape)
    gamma = gamma_init(gamma_shape)

    return ConvBatchEnsemble(layer, alpha, gamma, ensemble_bias, ensemble_act, rank)
end

Review comment (on lines +73 to +88): Same comment as last time about keeping things simple and general. Maybe it makes sense to have a constructor that takes in a Conv layer directly?
Reply: Yeah, it does. I guess we can have both as well.
Reply: We actually need the input/output dimensions to create the alpha/gamma matrices. Might as well keep them in the signature, or we'll have to infer them from the conv layer's struct and that might change anytime in Flux source?
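A quick usage sketch of this constructor (illustrative sizes; the batch is assumed to already be repeated ensemble_size times, as in the MNIST example above):

layer = ConvBatchEnsemble((5, 5), 1 => 6, 1, 4, relu)  # rank 1, ensemble of 4
x = rand(Float32, 28, 28, 1, 8 * 4)                    # 8 samples repeated for 4 members
size(layer(x))                                         # (24, 24, 6, 32)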

@functor ConvBatchEnsemble

function (be::ConvBatchEnsemble)(x)
    # Conv Batch Ensemble params
    layer = be.layer
    alpha = be.alpha
    gamma = be.gamma
    e_b = be.ensemble_bias
    e_σ = be.ensemble_act

    batch_size = size(x)[end]
    in_size = size(alpha)[1]
    out_size = size(gamma)[1]
    ensemble_size = size(alpha)[2]
    samples_per_model = batch_size ÷ ensemble_size

    # Alpha, gamma shapes - [units, ensembles, rank]
    e_b = repeat(e_b, samples_per_model)
    alpha = repeat(alpha, samples_per_model)
    gamma = repeat(gamma, samples_per_model)
    # Reshape alpha, gamma to [units, batch_size, rank]
    e_b = reshape(e_b, (1, 1, out_size, batch_size))
    alpha = reshape(alpha, (1, 1, in_size, batch_size))
    gamma = reshape(gamma, (1, 1, out_size, batch_size))

    perturbed_x = x .* alpha
    output = layer(perturbed_x) .* gamma
    output = e_σ.(output .+ e_b)

    return output
end

Review comment (on the `e_b = reshape(...)` line): Size of the bias seems relevant here.
Reply: How do we know that the shape of the bias allocated can fit into the container it's expected to be in?
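As a sanity check of the rank-1 construction described in the docstring, here is a minimal dense-case sketch (illustrative only; single ensemble member, no bias) showing that perturbing the input with alpha and the output with gamma is equivalent to using the element-wise modulated weight W .* (gamma * alpha'):

using LinearAlgebra

W = randn(Float32, 4, 3)    # shared slow weight, (out_dim, in_dim)
alpha = randn(Float32, 3)   # fast weight R for one member
gamma = randn(Float32, 4)   # fast weight S for one member
x = randn(Float32, 3)       # one input sample

y_perturbed = gamma .* (W * (x .* alpha))  # what the BatchEnsemble forward pass computes
y_explicit = (W .* (gamma * alpha')) * x   # explicit per-member weight W .* (S * R')
@assert isapprox(y_perturbed, y_explicit; atol = 1f-4)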