using Flux, Random, Statistics, CUDA, Plots, JLD2
using Infiltrator

CUDA.device!(1)

# Load the feature matrix (assumed n_features × n_samples; samples are columns,
# consistent with the `size(X, 1)` / column-indexing below) and integer labels.
X = load("features.jld2")["X"]
y = load("labels.jld2")["y"]

# BUG FIX: the original bound `n_classes` to the *vector* of unique labels
# (`(unique(y))`), not its length, which would later be passed to
# `Flux.Dense(hidden_dim, output_dim)` as a dimension. Labels are assumed to be
# integers in 1:n_classes (see `onehotbatch`/`onecold` usage below).
n_classes = length(unique(y))
n_features = size(X, 1)

"""
    MLP

Two-layer perceptron (`Dense(relu) → Dense`) whose logits are passed through a
`LayerNorm`. Callable as `m(x)` on a feature matrix with samples in columns.
"""
struct MLP
    layers::Flux.Chain
    norm::Flux.LayerNorm
end

"""
    MLP(input_dim, hidden_dim, output_dim)

Construct an `MLP` mapping `input_dim` features to `output_dim` logits through
one hidden layer of width `hidden_dim`.
"""
function MLP(input_dim::Int, hidden_dim::Int, output_dim::Int)
    model = Flux.Chain(
        Flux.Dense(input_dim, hidden_dim, relu),
        Flux.Dense(hidden_dim, output_dim),
    )
    norm = Flux.LayerNorm(output_dim)
    return MLP(model, norm)
end

# Forward pass: chain, then layer-norm over the logits.
function (m::MLP)(x)
    return m.norm(m.layers(x))
end

# BUG FIX: without registering the struct as a functor, `gpu(model)` and
# `Flux.setup(Adam(lr), model)` cannot see (or move) its parameters — the model
# would silently train nothing / stay on the CPU.
Flux.@functor MLP

"""
    train_test_split(X, y, test_ratio)

Randomly partition the columns of `X` (and matching entries of `y`) into a
train and a test set. Returns `(X_train, y_train, X_test, y_test)`.
"""
function train_test_split(X, y, test_ratio::Float64)
    n = size(X, 2)
    idx = shuffle(1:n)
    n_test = floor(Int, test_ratio * n)
    test_idx = idx[1:n_test]
    train_idx = idx[n_test+1:end]
    return X[:, train_idx], y[train_idx], X[:, test_idx], y[test_idx]
end

"""
    loss_fn(model, x, y)

Logit cross-entropy of `model(x)` against integer labels `y`.
"""
function loss_fn(model, x, y)
    preds = model(x)
    # BUG FIX: one-hot against the fixed global label set 1:n_classes. The
    # original used `1:maximum(y)`, so the target's row count depended on which
    # labels happened to appear in the batch and could mismatch the model's
    # output dimension (and silently drop classes).
    y_oh = Flux.onehotbatch(y, 1:n_classes)
    return Flux.logitcrossentropy(preds, y_oh)
end

"""
    accuracy(model, x, y)

Fraction of columns of `x` whose arg-max predicted class equals `y`.
"""
function accuracy(model, x, y)
    pred_labels = Flux.onecold(model(x))
    return mean(pred_labels .== y)
end

"""
    train!(model, X_train, y_train, X_test, y_test; n_epochs=10, batch_size=32, lr=0.001)

Mini-batch Adam training loop. Mutates `model` in place and returns
`(train_losses, test_losses, test_accuracies)`, one entry per epoch.
Any trailing partial batch (`n_train % batch_size` samples) is skipped.
"""
function train!(model, X_train, y_train, X_test, y_test;
                n_epochs=10, batch_size=32, lr=0.001)
    opt = Flux.setup(Adam(lr), model)
    train_losses = Float32[]
    test_losses = Float32[]
    test_accuracies = Float32[]
    n_train = size(X_train, 2)
    n_batches = div(n_train, batch_size)
    # Hoisted out of the epoch loop: the test set is constant, so upload it to
    # the GPU once instead of once per epoch.
    X_test_gpu, y_test_gpu = gpu(X_test), gpu(y_test)
    for epoch in 1:n_epochs
        # Typed accumulator — the original `[]` was a Vector{Any}.
        batch_losses = Float32[]
        shuffled_idx = shuffle(1:n_train)
        for b in 0:n_batches-1
            batch_idx = shuffled_idx[b*batch_size+1:(b+1)*batch_size]
            x_gpu = gpu(X_train[:, batch_idx])
            y_gpu = gpu(y_train[batch_idx])
            loss_val, grads = Flux.withgradient(model) do m
                loss_fn(m, x_gpu, y_gpu)
            end
            Flux.update!(opt, model, grads[1])
            push!(batch_losses, loss_val)
        end
        push!(train_losses, mean(batch_losses))
        push!(test_losses, loss_fn(model, X_test_gpu, y_test_gpu))
        push!(test_accuracies, accuracy(model, X_test_gpu, y_test_gpu))
        println("epoch $(epoch)! train loss: $(train_losses[end]) test accuracy: $(test_accuracies[end])")
    end
    return train_losses, test_losses, test_accuracies
end

"""
    main()

Split the data, train the MLP on the GPU, and save a train/test-loss plot to
`losses.png`.
"""
function main()
    input_dim = n_features
    hidden_dim = 32
    output_dim = n_classes
    X_train, y_train, X_test, y_test = train_test_split(X, y, 0.2)
    model = MLP(input_dim, hidden_dim, output_dim) |> gpu
    train_losses, test_losses, test_accuracies = train!(
        model, X_train, y_train, X_test, y_test;
        n_epochs=20, batch_size=64, lr=0.001,
    )
    plot(1:length(train_losses), train_losses;
         label="train loss", xlabel="epoch", ylabel="loss",
         title="train/test loss")
    plot!(1:length(test_losses), test_losses; label="test loss")
    savefig("losses.png")
end

main()