Latent Space Translation via Semantic Alignment

Zero-shot stitching of neural networks across architectures, domains, and modalities

NeurIPS 2023

Overview

Different neural networks often learn similar latent representations when exposed to semantically related data, yet this similarity is not immediately usable for transferring knowledge between models. Latent Space Translation via Semantic Alignment introduces a remarkably simple approach: directly estimating transformations between latent spaces using standard algebraic procedures with closed-form solutions.

This method enables zero-shot stitching of independently trained encoders and decoders without any retraining. Given only a small set of semantically corresponding anchor points, we can seamlessly combine networks with different architectures, trained on different domains and even different modalities, such as stitching text encoders with vision decoders.

Method overview: we directly translate between the absolute latent spaces X and Y by estimating a transformation T from a small set of semantically aligned anchor points, which lets us reuse arbitrarily pre-trained decoders without any training on relative representations.

Method: Direct Translation Between Latent Spaces

Problem Formulation

Given two latent spaces \(\mathbf{X} \in \mathbb{R}^{n \times d_1}\) and \(\mathbf{Y} \in \mathbb{R}^{n \times d_2}\), our objective is to estimate a transformation \(\mathcal{T}\) that translates between them:

\[\mathbf{Y} \approx \mathcal{T}(\mathbf{X})\]

This is achieved by exploiting semantic alignment through a set of parallel anchors—corresponding data points in both spaces that represent the same high-level concepts.

Pre-processing

  1. Dimension matching: Zero-pad the smaller space to match dimensions
  2. Standardization: Apply standard scaling (zero mean, unit variance) using statistics computed on anchor sets
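
A minimal sketch of these two pre-processing steps, assuming numpy arrays x_anchors and y_anchors of anchor embeddings (hypothetical names, not from the official code), possibly with different dimensionalities:

import numpy as np

def zero_pad(a: np.ndarray, target_dim: int) -> np.ndarray:
    # Right-pad the embedding dimension with zeros so both spaces share one size.
    return np.pad(a, ((0, 0), (0, target_dim - a.shape[1])))

def standardize(a: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    # Standard scaling with statistics estimated on the anchor set only.
    return (a - mean) / std

d = max(x_anchors.shape[1], y_anchors.shape[1])
x_pad, y_pad = zero_pad(x_anchors, d), zero_pad(y_anchors, d)

# Anchor statistics are reused later for every point that gets translated;
# the small epsilon guards the zero-padded (constant) dimensions.
x_mean, x_std = x_pad.mean(0), x_pad.std(0) + 1e-8
y_mean, y_std = y_pad.mean(0), y_pad.std(0) + 1e-8
x_anchors, y_anchors = standardize(x_pad, x_mean, x_std), standardize(y_pad, y_mean, y_std)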

Transformation Classes

We investigate four transformation classes, from most general to most constrained:

  • affine: General affine transformation \(\mathcal{T}(\mathbf{x}) = \mathbf{Rx} + \mathbf{b}\) optimized via gradient descent
  • linear: Linear transformation (\(\mathbf{b} = \mathbf{0}\)) with closed-form least squares solution
  • l-ortho: Orthogonalization of the linear solution via SVD
  • ortho: Optimal orthogonal transformation via Procrustes analysis

The transformation is estimated using only the anchor set, then applied to translate any point from the source to target space.
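
As an illustration (variable names are ours, not the authors'), the closed-form variants can be estimated with a few lines of numpy; for ortho this is the classic orthogonal Procrustes solution \(\mathbf{R} = \mathbf{U}\mathbf{V}^\top\), where \(\mathbf{U}\boldsymbol{\Sigma}\mathbf{V}^\top\) is the SVD of \(\mathbf{X}_{\text{anchors}}^\top \mathbf{Y}_{\text{anchors}}\):

import numpy as np

def estimate_linear(x_a: np.ndarray, y_a: np.ndarray) -> np.ndarray:
    # Least-squares solution of x_a @ R ≈ y_a (the linear class, b = 0).
    R, *_ = np.linalg.lstsq(x_a, y_a, rcond=None)
    return R

def orthogonalize(R: np.ndarray) -> np.ndarray:
    # Closest orthogonal matrix to R in Frobenius norm (the l-ortho class).
    u, _, vt = np.linalg.svd(R)
    return u @ vt

def estimate_ortho(x_a: np.ndarray, y_a: np.ndarray) -> np.ndarray:
    # Orthogonal Procrustes: the rotation minimizing ||x_a @ R - y_a||_F (the ortho class).
    u, _, vt = np.linalg.svd(x_a.T @ y_a)
    return u @ vt

# Any standardized source point x is then translated as y_hat = x @ R.

The affine variant adds a bias term and, as noted above, is fit by gradient descent rather than in closed form.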


Cross-Architecture Stitching

We demonstrate the ability to stitch together encoders and decoders from completely different architectures without any retraining.

Experimental Setup

  • Vision: 6 different architectures (ResNet, ViT variants, RexNet, CLIP vision encoder)
  • Text: 7 different language models (BERT variants, RoBERTa, ALBERT, Electra, XLM-R, CLIP text encoder)
  • Datasets: CIFAR10, CIFAR100, MNIST, Fashion-MNIST, TREC, AG News, DBpedia, IMDB
  • Decoders: SVM and MLP classifiers trained on each encoder’s specific embeddings
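
Putting the pieces together, here is a hedged sketch of the zero-shot stitching protocol, reusing the helpers sketched above; encode_a, encode_b, clf_b, anchor_inputs, and test_inputs are placeholders rather than names from the paper's code, and we assume for simplicity that clf_b was fit on encoder B's padded, standardized embeddings:

# Embed the shared anchors with both encoders, then pad and standardize each space.
x_raw, y_raw = encode_a(anchor_inputs), encode_b(anchor_inputs)
d = max(x_raw.shape[1], y_raw.shape[1])
x_pad, y_pad = zero_pad(x_raw, d), zero_pad(y_raw, d)
x_mean, x_std = x_pad.mean(0), x_pad.std(0) + 1e-8
y_mean, y_std = y_pad.mean(0), y_pad.std(0) + 1e-8

# Estimate the translation from the anchors only.
R = estimate_ortho(standardize(x_pad, x_mean, x_std),
                   standardize(y_pad, y_mean, y_std))

# Zero-shot stitching: embed with encoder A, translate, decode with B's classifier.
x_test = standardize(zero_pad(encode_a(test_inputs), d), x_mean, x_std)
preds = clf_b.predict(x_test @ R)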

Key Results

Using standard scaling, the orthogonal transformation (ortho) recovers most of the no-stitch performance:

  • CIFAR10: 0.93 accuracy (vs. 0.95 no-stitch baseline)
  • MNIST: 0.91 accuracy (vs. 0.96 baseline)
  • AG News: 0.66 accuracy (vs. 0.73 baseline)

The absolute baseline (no transformation) performs at random chance (0.16-0.25), confirming that latent spaces are not directly compatible. Our method recovers most of the original performance with minimal degradation.

Sensitivity to Anchor Quantity

Performance stabilizes with anchor sets comparable in size to the embedding dimension. The method is robust across a wide range of anchor quantities (100-1500), with diminishing returns beyond the embedding dimensionality.

Classification accuracy as a function of anchor set size across different datasets and transformation types. Performance stabilizes around the embedding dimension (512) with orthogonal transformations consistently achieving the best results.

Cross-Modality Stitching

A particularly striking result: stitching text encoders with image classifiers, and vice versa.

Experimental Setup

We use N24News, a multimodal news classification dataset with paired text and images. For each modality, we:

  1. Train separate classification heads (SVMs) on unimodal encodings
  2. Zero-shot stitch: apply image classifiers to translated text encodings and vice versa
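
A hedged scikit-learn sketch of this protocol, reusing estimate_ortho from above and assuming paired arrays text_emb and image_emb of pre-processed N24News encodings, a label array labels, and index arrays train_idx, anchor_idx, and test_idx (all hypothetical names):

from sklearn.svm import LinearSVC

# Unimodal heads: each classifier only ever sees its own modality during training.
clf_image = LinearSVC().fit(image_emb[train_idx], labels[train_idx])
clf_text = LinearSVC().fit(text_emb[train_idx], labels[train_idx])

# Translations estimated from the paired anchors (the same news items in both modalities).
R_t2i = estimate_ortho(text_emb[anchor_idx], image_emb[anchor_idx])
R_i2t = estimate_ortho(image_emb[anchor_idx], text_emb[anchor_idx])

# Zero-shot cross-modal stitching: text encodings into the image head, and vice versa.
acc_text_to_image = clf_image.score(text_emb[test_idx] @ R_t2i, labels[test_idx])
acc_image_to_text = clf_text.score(image_emb[test_idx] @ R_i2t, labels[test_idx])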

Scale Distribution Analysis

Encodings from different pre-trained models exhibit Gaussian-like scale distributions with well-defined means. Vision encoders tend to have larger average norms (11-90) while language encoders cluster around 11-32.

Distribution of embedding norms across different vision and language encoders. Each encoder produces representations with characteristic scale, motivating the use of standard scaling rather than L2 normalization.

Cross-Modal Results

When translating from strong text encoders to vision decoders, we observe:

  • Text-to-vision stitching exceeds unimodal vision performance in several cases
  • RoBERTa encoding → ViT classifier: 0.75 accuracy (vs. 0.40 unimodal vision)
  • This shows that a strong encoder, once translated, can compensate for a decoder trained on weaker representations

This asymmetry suggests that language models, pre-trained on more general data, produce more transferable representations than the vision models used here.

Cross-modality stitching performance on N24News dataset. Each cell shows accuracy when translating from row encoder to column decoder. Strong text encoders (RoBERTa, XLM-R) can improve weak vision decoders through translation, sometimes exceeding native unimodal performance.

Autoencoding: Generation Tasks

Unlike prior work focusing on classification, we test latent translation for image generation by stitching autoencoders.

Experimental Setup

  • Train pairs of identical CNN autoencoders with different random seeds
  • Translate encoder₁’s latents into the latent space expected by decoder₂
  • Evaluate reconstruction quality using MSE and cosine similarity
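
A minimal sketch of this evaluation, again reusing estimate_ortho and assuming functions encode1/decode2 that map images to and from flattened latent vectors for the two seeds (a hypothetical interface; the identical architectures make zero-padding unnecessary, and standardization is omitted for brevity):

import numpy as np

# Estimate the translation between the two latent spaces from the anchor images.
R = estimate_ortho(encode1(anchor_images), encode2(anchor_images))  # or estimate_linear

# Stitch: encode with AE1, translate, decode with AE2, then score the reconstruction.
recon = decode2(encode1(test_images) @ R)
mse = float(np.mean((recon - test_images) ** 2))

def mean_cosine(a: np.ndarray, b: np.ndarray) -> float:
    # Average cosine similarity between corresponding flattened samples.
    a, b = a.reshape(len(a), -1), b.reshape(len(b), -1)
    num = np.sum(a * b, axis=1)
    den = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
    return float(np.mean(num / den))

cos_sim = mean_cosine(recon, test_images)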

Results

All transformation methods (affine, linear, l-ortho, ortho) produce visually similar reconstructions with low MSE:

  • MNIST: 0.02 reconstruction MSE
  • CIFAR10: 0.05 reconstruction MSE
  • CIFAR100: 0.06 reconstruction MSE

The similarity across methods suggests that autoencoder latent spaces are related by transformations broader than purely orthogonal ones, unlike the classification case where orthogonal transformations dominate.

Left: Visual reconstruction quality when stitching autoencoders across MNIST, CIFAR10, and CIFAR100. All transformation types achieve high-fidelity reconstructions. Right: Reconstruction MSE as a function of anchor set size, showing consistent performance across different anchor quantities.

Role of Normalization

We compare two normalization strategies:

Standard Scaling (Preferred)

Preserves scale information in embeddings, which proves important:

  • Vision tasks: Minimal performance difference
  • Text tasks: Significant degradation with L2 normalization (e.g., TREC: 0.79 → 0.44 with linear)

This aligns with NLP literature showing that embedding norms encode semantic information (e.g., token frequency).

L2 Normalization

Removes scale, requiring decoders to be scale-invariant. While this generalizes the transformation class, it discards potentially useful information encoded in the norm.

For classification with softmax, scale invariance is naturally achieved through the temperature parameter. Our analysis shows MLPs with monotonic activations (ReLU, tanh) maintain performance across rescaling, while non-monotonic functions (cosine) exhibit oscillatory behavior.
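
A small illustrative check (ours, not from the paper) of why rescaling is harmless in the ReLU case: a bias-free ReLU network is positively homogeneous, \(\mathrm{ReLU}(c\,\mathbf{h}) = c\,\mathrm{ReLU}(\mathbf{h})\) for \(c > 0\), so rescaling the input only rescales the logits, which softmax absorbs as a temperature change that never moves the argmax; a cosine activation offers no such guarantee:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(1, 64))
w1, w2 = rng.normal(size=(64, 128)), rng.normal(size=(128, 10))

def logits(inp, act):
    # One hidden layer, no biases.
    return act(inp @ w1) @ w2

relu = lambda h: np.maximum(h, 0.0)
for c in (0.1, 1.0, 10.0):
    print(c, np.argmax(logits(c * x, relu)), np.argmax(logits(c * x, np.cos)))
# The ReLU prediction is identical for every c > 0; the cosine prediction can change with c.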

Impact of rescaling factor on decoder performance with different activation functions. Monotonic activations (ReLU, tanh) with softmax maintain consistent accuracy across scales, while non-monotonic functions (cosine) show oscillatory behavior, highlighting the importance of scale-invariant architectures for L2-normalized representations.

Key Findings

  1. Simple transformations suffice: Orthogonal transformations (via Procrustes analysis) effectively align latent spaces across diverse settings

  2. Zero-shot model compositionality: Combine independently trained encoders and decoders without retraining—across architectures, domains, and modalities

  3. Cross-modal translation works: Text encoders can be stitched with vision decoders and vice versa, sometimes improving performance

  4. Task-dependent transformations: Classification benefits from orthogonal transformations, while generation requires broader transformation classes

  5. Scale matters: Standard scaling outperforms L2 normalization, especially for text, suggesting norms encode semantic information

  6. Practical efficiency: Closed-form solutions (Procrustes, least squares) are fast and deterministic, requiring no hyperparameter tuning


Citation

@inproceedings{maiorca2023latent,
  title     = {Latent Space Translation via Semantic Alignment},
  author    = {Maiorca*, Valentino and Moschella*, Luca and Norelli, Antonio and
               Fumero, Marco and Locatello, Francesco and Rodol{\`a}, Emanuele},
  booktitle = {Advances in Neural Information Processing Systems},
  year      = {2023}
}

Authors

Valentino Maiorca¹ · Luca Moschella¹ · Antonio Norelli¹ · Marco Fumero¹ · Francesco Locatello² · Emanuele Rodolà¹

¹Sapienza University of Rome · ²Institute of Science and Technology Austria (ISTA)

*Equal contribution