I wrote up the full explanation and math and even implemented super-fast Triton kernels:
www.pisoni.ai/posts/scaled...
Small-scale runs look surprisingly stable, but obviously, we need to scale up experiments to see how these dynamics hold up.
Let me know if you can help with that!
Posts by Raphael Pisoni
Because RBF Attention actively avoids the explosion of pre-norm magnitudes, you don't need QK-norm. The network maintains healthy gradient flow natively, and you drop the weird tangent-drift artifacts entirely.
That term acts as a built-in, dynamic L2 regularizer. The architecture itself actively penalizes large keys. It naturally keeps vector magnitudes in check without ever forcing a projection onto a hypersphere.
Scaled RBF Attention offers a pretty clean workaround. Measure similarity with the negative squared Euclidean distance (-||Q − K||²) instead of a dot product. Expanding gives 2(Q⋅K) − ||Q||² − ||K||²; the ||Q||² term is constant across keys for a given query, so the softmax cancels it and (up to scaling) we end up with (Q⋅K) − ||K||².
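Quick numpy sanity check of that cancellation (just the math, nothing like the actual Triton kernel):

```python
import numpy as np

rng = np.random.default_rng(0)
q = rng.normal(size=(4, 8))   # 4 queries, dim 8
k = rng.normal(size=(6, 8))   # 6 keys

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# Score 1: negative squared Euclidean distance
d2 = ((q[:, None, :] - k[None, :, :]) ** 2).sum(-1)
attn_rbf = softmax(-d2)

# Score 2: expanded form. The ||q||^2 term is constant per row,
# so the softmax only "sees" 2(q.k) - ||k||^2.
score = 2 * (q @ k.T) - (k ** 2).sum(-1)[None, :]
attn_expanded = softmax(score)

print(np.allclose(attn_rbf, attn_expanded))  # True
```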
Because of this, every single optimization step increases the pre-norm vector length. As these vectors inflate, the gradients shrink toward zero. This geometric drift is one reason why we actually need weight decay: to constantly drag parameter norms back down.
But why would those unnormalized vectors get so big in the first place? Think about the geometry. When your optimizer takes a linear step on a circular surface (the normalized hypersphere), the update vector inherently points outward along the tangent.
The problem hides in the backward pass. While the forward pass becomes scale-invariant, the gradients flowing backward through the norm layer actually scale inversely with the unnormalized inputs.
Big pre-norm vectors --> tiny gradients.
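You can see that inverse scaling with a hand-written backward pass for y = x/||x|| (just an illustration, not the actual QK-norm code):

```python
import numpy as np

def norm_backward(x, g):
    # Gradient of y = x / ||x|| w.r.t. x, applied to upstream grad g:
    # dx = (g - (g.y) y) / ||x||  -- note the 1/||x|| factor.
    n = np.linalg.norm(x)
    y = x / n
    return (g - (g @ y) * y) / n

rng = np.random.default_rng(0)
x = rng.normal(size=16)
g = rng.normal(size=16)  # some upstream gradient

small = np.linalg.norm(norm_backward(x, g))
big = np.linalg.norm(norm_backward(10 * x, g))

print(big / small)  # 0.1: a 10x bigger input gives a 10x smaller gradient
```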
QK-norm is definitely the industry standard right now. It stops "magnitude bullying" by projecting queries and keys onto a hypersphere before taking the dot product, and it makes training more stable. Sounds great, right?
I never really considered how dangerous QK-norm actually is before working on RBF Attention. While solving some obvious issues, it can be the cause of some much less obvious ones.🧵
I've packed up everything so you can try it on your own data and use cases. Let me know what you find!!
Blog-post: pisoni.ai/posts/halo/
Code: github.com/4rtemi5/halo
BTW I'm also currently looking for ML research roles so let me know if you know something or someone! ;)
These kinds of results are huge for safety-critical classification, fighting SSL collapse, or for multimodal models like CLIP that need robust rejection thresholds. If you're working on representation learning and have the resources to scale these results up, let me know.
Usually, better Out-of-Distribution (OOD) detection means you pay a "safety tax" and lose base accuracy. With HALO, that penalty vanishes.
On CIFAR-10 (ResNet-18):
• Base accuracy is maintained
• ECE drops from ~8% down to 1.5%
• OOD False Positives (FPR@95) cut by >50%
With a little high school math we can make this loss super efficient and even avoid the curse of dimensionality that plagues these kinds of losses.
Then wire a virtual K+1 class permanently to the origin. This creates a mathematically sound, zero-parameter "Abstain Class". Outlier data naturally falls into this origin sink.
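Here's a toy numpy sketch of the idea (my reading of the posts above, with made-up random prototypes standing in for learned class features; not the actual HALO code):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
K, D = 10, 32
prototypes = rng.normal(size=(K, D)) * 3.0  # stand-ins for learned class centers

def l2_probs_with_abstain(z):
    # Logits are negative squared L2 distances to each class center...
    d2 = ((z[None, :] - prototypes) ** 2).sum(-1)
    # ...plus a zero-parameter "abstain" class wired to the origin:
    # its distance is just ||z||^2.
    d2 = np.append(d2, (z ** 2).sum())
    return softmax(-d2)

# A feature near a class center is confidently classified;
# a feature near the origin falls into the abstain sink.
in_dist = l2_probs_with_abstain(prototypes[3] + 0.1 * rng.normal(size=D))
outlier = l2_probs_with_abstain(np.zeros(D))

print(in_dist.argmax())  # 3
print(outlier.argmax())  # 10 (the virtual abstain class)
```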
This leads to radial explosion and a jagged latent space, and it allows the model to overconfidently hallucinate on pure noise.
The fix? First switch to L2 distances instead of dot-products. Maximum confidence now maps to a finite location (distance = 0) rather than an asymptote.
The root of the problem is the input to the Categorical Cross-Entropy: unconstrained dot-products. Because of how the Softmax function works, the network is forced to push its features infinitely far out to reach 100% confidence.
There is no finite "perfect" score.
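You can check the asymptote numerically (plain numpy, just for illustration):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# With dot-product logits, confidence only approaches 1 asymptotically:
# no finite logit value ever yields a "perfect" score.
for s in [1.0, 10.0, 30.0]:
    p = softmax(np.array([s, 0.0, 0.0]))[0]
    print(s, p)  # creeps toward 1 but stays below it
```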
Neural networks have a fundamental problem. Feed them garbage data and instead of admitting that they are confused, they will confidently hallucinate.
I just open-sourced the HALO-Loss to try and fix this. It gives the model a mathematically sound *I don't know!* button.🧵
I dove deeper into the rabbit hole of RBF-Attention. I refined the Triton kernel, added register-tokens and developed SuSiE positional embedding as a replacement for RoPE in Euclidean space.
Go have a look at the repo or the blogpost in the comments if you're interested! :)
Yes, it's all in there! Probably much more than you need to get started! :)
FYI: bsky.app/profile/4rte...
I'm open-sourcing all my code for scaled RBF-Attention. If you want to roast my Triton knowledge or check how far you have to scale things to make them break, feel free to have a look!😅
github.com/4rtemi5/rbf_...
Obvious caveats before I sound crazy:
Scale: I've only tested this on tiny models and data. It all might collapse on larger models.
Performance: The Triton kernel works, but it's a first pass. It won't outperform proper Flash-Attention anytime soon.
I trained a few models on the TinyStories dataset and they not only converge well, but consistently seem to outperform standard SDPA.🤯
To get a semi-fair comparison of speed and memory requirements, I implemented the whole thing as a custom Triton kernel (heavily inspired by Tri Dao's Flash-Attention), and my first shot fares quite well against the hyper-optimized torch implementation.
Because of that matmul the heavy lifting is still fast on GPU and the L2 norms are just lightweight additions you can broadcast on top.
Computing the full NxN distance matrix sounds like a great way to OOM and tank your speed, but with a simple math trick you can expand the squared distance and use a matmul:
||q-k||² = ||q||² + ||k||² - 2(q@k.T)
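A quick numpy check that the expansion matches the naive pairwise computation (toy shapes here, nothing like the actual Triton kernel):

```python
import numpy as np

rng = np.random.default_rng(0)
q = rng.normal(size=(128, 64))
k = rng.normal(size=(256, 64))

# Naive: materialize all pairwise differences -> memory blow-up at scale
naive = ((q[:, None, :] - k[None, :, :]) ** 2).sum(-1)

# Expanded: one matmul plus two lightweight broadcast norm terms
sq_q = (q ** 2).sum(-1, keepdims=True)  # ||q||^2, shape (128, 1)
sq_k = (k ** 2).sum(-1)[None, :]        # ||k||^2, shape (1, 256)
expanded = sq_q + sq_k - 2 * (q @ k.T)

print(np.allclose(naive, expanded))  # True
```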
The geometric intuition is pretty simple: standard attention uses dot-product similarity, so it cares about both the angle between and the magnitudes of the query and key.
RBF uses the Euclidean distance instead, so it measures how close the vectors actually are in space.
Plot of two loss curves where RBF-attention consistently outperforms scaled dot-product attention.
For some reason I decided to swap out standard dot-product attention for a scaled-rbf kernel.
Pretty much expected it to fail to converge or be impossibly slow but the scaled-rbf-attention is getting unexpectedly good results?? 👇
Sorry for vagueposting but I was too excited to not post at all.😅 Will share more soon!🤞