NM Neural Decoding Demo

Decoding Motor Intent from Neural Signals

Comparing MLP, 2D CNN, LSTM, and Transformer architectures for brain-computer interface decoding

This project decodes hand position and velocity from 95-channel neural spike data recorded during a center-out reaching task. The goal: predict where a monkey is moving its hand using only brain signals. Four deep learning architectures were trained and compared on the same dataset (contdata95.mat), with spike counts binned at 50 ms.
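The sequence models below consume sliding windows of 32 consecutive 50 ms bins, each paired with the kinematics at the window's final bin. A minimal sketch of that windowing step, using synthetic stand-in data (the actual field names inside contdata95.mat are not shown here):

```python
import numpy as np

def make_windows(spikes, kinematics, window=32):
    """Build sliding windows of binned spike counts, each paired with the
    kinematics (x/y position, x/y velocity) at the window's last bin."""
    X = np.stack([spikes[t - window:t] for t in range(window, len(spikes) + 1)])
    y = kinematics[window - 1:]
    return X, y

# Synthetic stand-in: 1000 bins (50 ms each) of 95-channel spike counts.
rng = np.random.default_rng(0)
spikes = rng.poisson(2.0, size=(1000, 95)).astype(np.float32)
kin = rng.normal(size=(1000, 4)).astype(np.float32)

X, y = make_windows(spikes, kin)
print(X.shape, y.shape)  # (969, 32, 95) (969, 4)
```

Each of the 969 training examples is a (32 × 95) window; the MLP instead takes a single 95-dimensional bin.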

Live Decoding Replay

Watch each model attempt to decode hand position from neural activity in real time. Blue = actual trajectory, red = model prediction.

Panels: Neural Activity (95 channels) · Decoded Position · live Correlation readout

Architecture Comparison

Pearson correlation between predicted and actual kinematics on the held-out test set.

Model         Avg corr   X pos   Y pos   X vel   Y vel
MLP           0.876      0.878   0.807   0.917   0.902
LSTM          0.987      0.989   0.986   0.987   0.985
2D CNN        0.788      0.608   0.615   0.969   0.960
Transformer   0.740      0.683   0.692   0.787   0.798

Model Architectures

Fully Connected NN

Input (95)
FC 64
FC 64
FC 128
FC 128
FC 256
Output (4)
BatchNorm + Dropout (0.3) after each hidden layer. Takes a single time bin as input, so it uses no temporal context.
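A hedged PyTorch sketch of the layer stack above (activation choice and training details are assumptions, not taken from the source):

```python
import torch
import torch.nn as nn

def block(n_in, n_out, p=0.3):
    """FC layer followed by BatchNorm, ReLU, and Dropout, as described above."""
    return nn.Sequential(nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out),
                         nn.ReLU(), nn.Dropout(p))

mlp = nn.Sequential(
    block(95, 64), block(64, 64), block(64, 128),
    block(128, 128), block(128, 256),
    nn.Linear(256, 4),  # x/y position and x/y velocity
)

x = torch.randn(8, 95)       # one 50 ms bin of spike counts per sample
out = mlp.eval()(x)          # eval mode: BatchNorm uses running stats
print(out.shape)             # torch.Size([8, 4])
```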

2D CNN

Input (1 x 32 x 95)
Conv2D 32 (k=5)
Conv2D 64 (k=5)
Conv2D 128 (k=3)
Conv2D 128 (k=3)
AdaptiveAvgPool2D
FC 64
Output (4)
Treats (time x channels) as a grayscale image. Conv2D captures spatio-temporal patterns. BatchNorm + Dropout (0.3).
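A hedged sketch of the convolutional stack, treating each (32 × 95) window as a one-plane image; padding and activation choices are assumptions:

```python
import torch
import torch.nn as nn

def conv(cin, cout, k):
    """Conv2D block with BatchNorm, ReLU, and Dropout (0.3)."""
    return nn.Sequential(nn.Conv2d(cin, cout, k, padding=k // 2),
                         nn.BatchNorm2d(cout), nn.ReLU(), nn.Dropout(0.3))

cnn = nn.Sequential(
    conv(1, 32, 5), conv(32, 64, 5), conv(64, 128, 3), conv(128, 128, 3),
    nn.AdaptiveAvgPool2d(1),   # collapse the time x channel map to 128 features
    nn.Flatten(),
    nn.Linear(128, 64), nn.ReLU(),
    nn.Linear(64, 4),
)

x = torch.randn(8, 1, 32, 95)  # batch of 32-bin x 95-channel windows
out = cnn.eval()(x)
print(out.shape)               # torch.Size([8, 4])
```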

LSTM

Input (32 x 95)
BatchNorm
LSTM 128
Last Hidden State
FC → Output (4)
32-step sequence input captures 1.6s of temporal context. Only the final hidden state is used for prediction.
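A hedged sketch of the LSTM decoder: normalize channels, run the 32-step sequence, and predict from the final hidden state only (the BatchNorm placement is an assumption):

```python
import torch
import torch.nn as nn

class LSTMDecoder(nn.Module):
    def __init__(self, n_channels=95, hidden=128, n_out=4):
        super().__init__()
        self.norm = nn.BatchNorm1d(n_channels)  # per-channel normalization
        self.lstm = nn.LSTM(n_channels, hidden, batch_first=True)
        self.fc = nn.Linear(hidden, n_out)

    def forward(self, x):  # x: (batch, 32, 95)
        # BatchNorm1d expects (batch, channels, time), so transpose around it.
        x = self.norm(x.transpose(1, 2)).transpose(1, 2)
        _, (h_n, _) = self.lstm(x)
        return self.fc(h_n[-1])  # final hidden state -> 4 kinematic outputs

x = torch.randn(8, 32, 95)
out = LSTMDecoder().eval()(x)
print(out.shape)  # torch.Size([8, 4])
```

Using only the last hidden state means the prediction for each window depends on the full 1.6 s of context the LSTM has accumulated.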

Transformer

Input (32 x 95)
BatchNorm
Linear (95 → 130)
TransformerEncoder (5 heads)
Mean Pool
FC → Output (4)
Encoder-only architecture with self-attention over 32 time steps. Mean pooling across the sequence for prediction.
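A hedged sketch of the encoder-only Transformer: project 95 channels up to d_model = 130 (divisible by 5 heads), self-attend over the 32 steps, then mean-pool. The encoder depth (2 layers here) is an assumption not stated in the source:

```python
import torch
import torch.nn as nn

class TransformerDecoder(nn.Module):
    def __init__(self, n_channels=95, d_model=130, n_heads=5, n_layers=2):
        super().__init__()
        self.norm = nn.BatchNorm1d(n_channels)
        self.proj = nn.Linear(n_channels, d_model)  # 95 -> 130
        layer = nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, n_layers)
        self.fc = nn.Linear(d_model, 4)

    def forward(self, x):  # x: (batch, 32, 95)
        x = self.norm(x.transpose(1, 2)).transpose(1, 2)
        z = self.encoder(self.proj(x))     # self-attention over 32 steps
        return self.fc(z.mean(dim=1))      # mean pool across the sequence

x = torch.randn(8, 32, 95)
out = TransformerDecoder().eval()(x)
print(out.shape)  # torch.Size([8, 4])
```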

Key Insight

Regularization matters more than architecture complexity. The original MLP used 0.5 dropout, which hurt learning. Reducing dropout to 0.3 brought the single-bin MLP to 0.88 average correlation. The LSTM dominates at 0.987 correlation, benefiting from recurrent memory over 32 time steps. The 2D CNN (0.79) captures spatio-temporal patterns but struggles with position decoding. The Transformer (0.74) underperforms here likely due to limited hyperparameter tuning and the small dataset size.

Original Results

Decoded trajectories from the notebooks. Blue = actual, Red = predicted.

MLP Decode

MLP predicted vs actual hand position

With proper regularization (dropout 0.3), the MLP achieves 0.876 average correlation from single time-bin inputs.

2D CNN Decode

2D CNN predicted vs actual hand position

The 2D CNN excels at velocity (0.97) but is weaker on position (0.61), averaging 0.79 correlation.

LSTM Decode

LSTM predicted vs actual hand position

The LSTM nearly perfectly tracks the actual trajectory, achieving 0.987 average correlation.

Transformer Decode

Transformer predicted vs actual hand position

The Transformer achieves 0.74 average correlation, likely limited by overfitting on the small dataset.