
Joint Expert Embedding Training (JEET)

Updated 10 February 2026
  • JEET is a modular deep learning framework that shares a common embedding backbone with specialized decoders, optimizing parameter efficiency and mitigating catastrophic forgetting.
  • It employs joint optimization combining language modeling, knowledge distillation, and multitask learning to balance general representation and expert specialization.
  • JEET streamlines inference by activating a single expert module per input, significantly reducing computational overhead in complex multi-domain pipelines.

Joint Expert Embedding Training (JEET) is a modular deep learning paradigm designed to maximize parameter efficiency and avoid catastrophic forgetting across multi-domain or multi-task settings, while retaining expert-level specialization for each subtask. In JEET, a single shared embedding or backbone encoder is paired with multiple downstream expert modules which are trained jointly, allowing high-capacity feature extraction to be amortized across specialized decoders (experts) tuned for specific languages, domains, or tasks. The JEET framework has been explored across natural language processing and speech diarization, enabling modular yet unified approaches for complex pipelines (Pálka et al., 2024, Al-Maamari et al., 2024).

1. Conceptual Foundations and Motivation

JEET emerges from the intersection of modular Mixture-of-Experts (MoE), knowledge distillation, and multi-task learning. Standard MoE architectures assign inputs to expert subnetworks via a gating or routing function, yielding parameter isolation and specialization. Conventional multitask pipelines often train separate embeddings and decoders for each task, which duplicates parameters and, when tasks are trained sequentially, increases the risk of catastrophic forgetting. JEET addresses both problems with a joint optimization regime in which a common embedding layer is shared among multiple expert decoders, each responsible for a distinct domain or subtask. This configuration enables expert-level performance, improved memory and compute efficiency, and robust knowledge retention across tasks (Al-Maamari et al., 2024).

In the language modeling context, JEET allows every expert module (e.g., for English, French, German, or Python) to share the same high-dimensional token embeddings, thus reducing redundancy without sacrificing specialization. In speaker diarization, analogous parameter efficiency is achieved by sharing a ResNet-based encoder for extracting frame-level features utilized in parallel by heads for speaker embedding, voice activity detection (VAD), and overlap speech detection (OSD) (Pálka et al., 2024).

2. Mathematical Formulation and Architectural Elements

At the core of JEET is a shared embedding layer or backbone encoder, with expert-specific decoders downstream:

  • LLMs: For vocabulary size $V$ and embedding dimension $d$, the shared embedding matrix $E \in \mathbb{R}^{V \times d}$ transforms token sequences to embedding sequences. For sequence $x = [x_1, \ldots, x_T]$, output is $H = [h_1, \ldots, h_T]$ with $h_t = E[x_t]$. Predicted logits for expert $j$ are $\ell^{(j)} = \mathrm{TransformerDecoder}_{\theta_j}(H)$ (Al-Maamari et al., 2024).
  • Router/Gating Mechanism: A sequence-level classifier (router) determines expert selection. Given pooled embedding $z = \mathrm{Pool}(H) \in \mathbb{R}^d$, scores are $s = U z + b \in \mathbb{R}^M$ for $M$ experts; softmax yields routing probabilities $g_j(x)$. At inference, $j^\star = \arg\max_j g_j(x)$, and only $\mathrm{Expert}_{j^\star}$ operates on $x$.
  • Speaker Diarization: A ResNet-101 encoder computes 8192-dimensional features from audio frames (64-d Mel filterbanks, every 10 ms); these are mapped through a linear layer to 256-dimensional per-frame embeddings. Three heads—speaker embedding, VAD, and OSD—access these embeddings in parallel (Pálka et al., 2024).
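
The language-model path described above (shared embedding lookup, pooled routing, single active expert) can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: the toy sizes, the choice of mean pooling for $\mathrm{Pool}$, and the linear maps standing in for the Transformer decoders (`expert_weights`) are all hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

V, d, M = 100, 16, 4          # toy vocabulary size, embedding dim, number of experts

E = rng.normal(size=(V, d))   # shared embedding matrix E in R^{V x d}
U = rng.normal(size=(M, d))   # router weights
b = np.zeros(M)               # router bias

# Hypothetical stand-ins for the expert decoders: one linear map per expert
# that projects hidden states to vocabulary logits.
expert_weights = [rng.normal(size=(d, V)) for _ in range(M)]


def softmax(s):
    """Numerically stable softmax over a 1-D score vector."""
    e = np.exp(s - s.max())
    return e / e.sum()


def route(x):
    """Return (j_star, g) for a sequence of token ids x."""
    H = E[x]                  # h_t = E[x_t], shape (T, d)
    z = H.mean(axis=0)        # mean pooling (an assumed choice of Pool)
    g = softmax(U @ z + b)    # routing probabilities g_j(x)
    return int(np.argmax(g)), g


def forward(x):
    """Run only the selected expert on x and return its logits."""
    j_star, g = route(x)
    logits = E[x] @ expert_weights[j_star]   # shape (T, V)
    return j_star, g, logits


x = np.array([3, 14, 15, 92, 65])
j_star, g, logits = forward(x)
print(j_star, g.round(3), logits.shape)
```

Only `expert_weights[j_star]` is touched in `forward`, which is the property that keeps per-input compute close to that of a single decoder.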

Joint Training Objective

  • LLMs: Each expert is distilled from a common teacher (e.g., GPT-2 Medium) using two losses:
    • Language modeling loss: $L_{\mathrm{LM}} = -\sum_{t=1}^T \log p_S^{(j^\star)}(y_t \mid x_{<t})$
    • Distillation loss: $L_{\mathrm{KD}} = \sum_{t=1}^T D_{KL}\big(p_T(\cdot \mid x_{<t}) \,\|\, p_S^{(j^\star)}(\cdot \mid x_{<t})\big)$
    • Total loss: $L_{\mathrm{total}} = \alpha L_{\mathrm{LM}} + \beta L_{\mathrm{KD}}$ (typically $\alpha, \beta \approx 0.5$)
  • Speaker Diarization: The network is trained on a weighted sum of three objectives (Pálka et al., 2024):

    • Additive Angular Margin (AAM, ArcFace) loss for speaker identity on VoxCeleb2:

    $$L_{\mathrm{embed}} = -\frac{1}{N} \sum_{i=1}^N \log \frac{ \exp(s \cos (\theta_{y_i,i} + m)) }{ \exp(s \cos (\theta_{y_i,i} + m)) + \sum_{j \neq y_i} \exp(s \cos \theta_{j,i}) }$$

    • Binary cross-entropy for VAD and OSD, with $L_{\mathrm{VAD}}$ and $L_{\mathrm{OSD}}$ over appropriate subsets.
    • Joint loss: $L_{\mathrm{total}} = \alpha L_{\mathrm{embed}} + \beta L_{\mathrm{VAD}} + \gamma L_{\mathrm{OSD}}$ (empirically, $\alpha=1, \beta=5, \gamma=2$)
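
The objectives in this section can be written out numerically. The sketch below is an assumption-laden illustration rather than either paper's code: the function names are hypothetical, the AAM scale `s=30.0` and margin `m=0.2` are placeholder defaults not taken from the source, and random logits stand in for real student and teacher outputs.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def lm_kd_loss(student_logits, teacher_logits, targets, alpha=0.5, beta=0.5):
    """Return (L_total, L_LM, L_KD), each summed over the T positions."""
    log_p_s = log_softmax(student_logits)            # (T, V)
    log_p_t = log_softmax(teacher_logits)            # (T, V)
    T = len(targets)
    l_lm = -log_p_s[np.arange(T), targets].sum()     # negative log-likelihood
    l_kd = (np.exp(log_p_t) * (log_p_t - log_p_s)).sum()   # sum_t KL(p_T || p_S)
    return alpha * l_lm + beta * l_kd, l_lm, l_kd


def aam_loss(cosines, labels, s=30.0, m=0.2):
    """Additive angular margin loss; cosines[i, j] = cos(theta_{j,i})."""
    idx = np.arange(len(labels))
    theta = np.arccos(np.clip(cosines, -1.0, 1.0))
    logits = s * cosines
    # The margin m is added to the angle of the target class only.
    logits[idx, labels] = s * np.cos(theta[idx, labels] + m)
    return -log_softmax(logits)[idx, labels].mean()


def diarization_loss(l_embed, l_vad, l_osd, alpha=1.0, beta=5.0, gamma=2.0):
    """Weighted sum with the weights quoted above (alpha=1, beta=5, gamma=2)."""
    return alpha * l_embed + beta * l_vad + gamma * l_osd


rng = np.random.default_rng(0)
T, V = 6, 10
student = rng.normal(size=(T, V))
teacher = rng.normal(size=(T, V))
targets = rng.integers(0, V, size=T)
total, l_lm, l_kd = lm_kd_loss(student, teacher, targets)
print(f"L_LM={l_lm:.3f}  L_KD={l_kd:.3f}  L_total={total:.3f}")
```

Two sanity checks follow from the definitions: the distillation term vanishes when student and teacher logits coincide, and `aam_loss` with `m=0` reduces to ordinary softmax cross-entropy on the scaled cosines.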

3. Training Procedures and Regimes

JEET employs joint optimization strategies that maintain balanced exposure of all experts to their respective domains:

  • Batch Construction: Training batches are language- or task-pure; only the relevant expert is updated per batch, allowing the shared embedding to learn generalizable representations while decoders retain specificity.
  • Curricular Phasing: In speaker diarization, the encoder is first pre-trained as a standard per-segment model, then adapted for per-frame extraction; joint heads are fine-tuned in a second phase with alternated supervision (VoxCeleb2 for speaker loss, compound diarization corpora for VAD/OSD).
  • Optimizer Choices: Language JEET utilizes Adam; diarization JEET combines SGD (for speaker loss) and AdamW (for VAD/OSD).
  • Loss Schedules: Combined loss regimes slightly outperform alternating-loss schedules in sequence modeling (Al-Maamari et al., 2024).
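
The batch-construction rule above (domain-pure batches, one expert updated per batch, shared embedding updated every batch) can be illustrated with a small scheduler. This is a hedged sketch: `pure_batches`, the round-robin ordering, and the toy domain names are assumptions made for illustration, and the loop only counts updates instead of running a real optimizer step.

```python
import random
from collections import defaultdict


def pure_batches(samples, batch_size, seed=0):
    """Yield (domain, batch) pairs in which every batch holds a single domain.

    `samples` is an iterable of (domain, example) pairs. Domains are visited
    round-robin so that each expert receives batches at a similar rate.
    """
    rng = random.Random(seed)
    by_domain = defaultdict(list)
    for domain, example in samples:
        by_domain[domain].append(example)

    queues = {}
    for domain, examples in by_domain.items():
        rng.shuffle(examples)
        queues[domain] = [examples[i:i + batch_size]
                          for i in range(0, len(examples), batch_size)]

    # Cycle over domains until every queue is exhausted.
    while any(queues.values()):
        for domain in sorted(queues):
            if queues[domain]:
                yield domain, queues[domain].pop(0)


samples = [(dom, f"{dom}-{i}") for dom in ("en", "fr", "de", "py") for i in range(5)]
updates = defaultdict(int)   # per-expert update counts
shared_updates = 0           # updates to the shared embedding

for domain, batch in pure_batches(samples, batch_size=2):
    # In a real training step, gradients would reach only the shared
    # embedding and the decoder for `domain`; the other experts are untouched.
    updates[domain] += 1
    shared_updates += 1

print(dict(updates), shared_updates)
```

With equal amounts of data per domain, every expert receives the same number of updates while the shared embedding is updated on every batch, which is the balanced exposure the training regime relies on.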

4. Architectural Scale and Inference Behavior

A hallmark of JEET is parameter amortization:

| Component | JEET-LLMs | JEET-Speaker Diarization |
|---|---|---|
| Shared Layer | Embedding table (32k × 768) | ResNet-101 body |
| Per-Expert Layer | GPT-2-110M decoder ×4 | Linear heads for VAD, OSD |
| Router | $4 \times 768$ classifier | Not applicable |
| Total Parameters | ≈464M (with 4 experts) | As per ResNet+heads |
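
The ≈464M total for the language configuration can be checked with simple arithmetic. The sketch assumes that "32k" means exactly 32,000 tokens, that the nominal 110M per decoder excludes the shared embedding table, and that the router bias is negligible; none of these details is stated in the source.

```python
V, d = 32_000, 768              # embedding table "32k x 768" (32k taken as 32,000)
n_experts = 4
decoder_params = 110_000_000    # nominal size of one GPT-2-110M decoder (assumed)

embedding_params = V * d                   # shared once across all experts
router_params = n_experts * d              # 4 x 768 classifier weights, bias ignored
jeet_total = embedding_params + n_experts * decoder_params + router_params

# Hypothetical comparison: each expert carrying its own copy of the table.
unshared_total = n_experts * (embedding_params + decoder_params)
saved = unshared_total - (embedding_params + n_experts * decoder_params)

print(f"shared embedding: {embedding_params / 1e6:.1f}M")
print(f"JEET total:       {jeet_total / 1e6:.1f}M")
print(f"saved by sharing: {saved / 1e6:.1f}M")
```

Under these assumptions the decoders dominate the budget, and sharing the table saves three copies of the embedding relative to four fully separate models.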

Inference is streamlined: in language JEET, only one expert is active per sequence (selected by the router). In diarization, a single forward pass produces all embedding and detection outputs, reducing inference time by ∼3× over traditional modular pipelines.

5. Empirical Performance and Comparative Analysis

JEET achieves perplexity and diarization error rates competitive with more resource-heavy or sequentially trained baselines:

  • Language Modeling (Perplexity, lower is better) (Al-Maamari et al., 2024):

    | Architecture | English | French | German | Python |
    |---|---|---|---|---|
    | PLE | 74.09 | 20.30 | 39.86 | 28.92 |
    | JEET | 75.79 | 20.12 | 40.38 | 27.02 |
    | MoE-CE | 78.96 | 20.91 | 41.92 | 27.16 |

JEET and PLE (separately distilled students) are nearly equivalent. MoE-CE (with an additional common expert) slightly underperforms.

  • Router Performance: TF-IDF-based router achieves 99.95% accuracy across expert classes.

  • Catastrophic Forgetting: JEET shows 0% forgetting under joint training, vs. up to 38% in sequentially trained single-student baselines.
  • Speaker Diarization (DER, lower is better) (Pálka et al., 2024):

    | System | DER | Miss | False alarm (FA) | Confusion |
    |---|---|---|---|---|
    | Baseline (modular) | 26.2% | NA | NA | NA |
    | Joint model + VBx | 26.6% | 15.3% | 4.5% | 6.8% |

Inference on 50 minutes of meeting audio is reduced from ~29 min (baseline) to ~9 min (JEET-style joint model).
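
The diarization figures are internally consistent: DER is by definition the sum of missed speech, false alarm, and speaker confusion, and the timing numbers give the roughly threefold speedup. The short check below uses only values quoted above; the variable names and the real-time-factor calculation are added for illustration.

```python
import math

# Percentages from the "Joint model + VBx" row.
miss, false_alarm, confusion = 15.3, 4.5, 6.8
der = miss + false_alarm + confusion
assert math.isclose(der, 26.6, abs_tol=1e-9)   # matches the reported DER

# Processing time for 50 minutes of meeting audio.
audio_min, baseline_min, joint_min = 50, 29, 9
speedup = baseline_min / joint_min             # about 3.2x
rtf_joint = joint_min / audio_min              # real-time factor of the joint model

print(f"DER = {der:.1f}%  speedup = {speedup:.1f}x  real-time factor = {rtf_joint:.2f}")
```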

6. Advantages, Limitations, and Practical Considerations

Advantages:

  • Parameter efficiency via shared embeddings/backbone.

  • Expert-level specialization per domain/language/task.
  • Drastically reduced risk of catastrophic forgetting under joint training or balanced mini-batching.
  • Streamlined inference and memory profile; only the selected expert and the shared embedding need to be loaded per input.
  • Modular extensibility; additional experts or tasks can be appended with minimal changes.

Limitations:

  • Decoder parameter count grows linearly with the number of experts (e.g., four experts cost 4× the parameters of a single decoder).
  • Router overhead and need for reliable separation of domains.
  • Necessity of balanced data for each expert; otherwise, representation and performance may become skewed.
  • No gradient flow between isolated experts; only the embedding is shared.
  • The shared embedding, being the largest single block, may require dedicated memory management.

Implementation notes include offloading large embedding matrices between CPU and GPU, pretraining the router or training it jointly with the experts, and using GPU-accelerated toolkits for efficient parameter distribution.

7. Extensions and Relation to Broader Research

JEET can be viewed as a generalization of modular multi-head architectures, and is a concrete path toward fully unified pipelines in complex settings such as diarization, multilingual modeling, or multitask ASR. In diarization, future JEET developments may integrate discriminative clustering objectives (e.g., DVBx), permutation-invariant losses (EEND), or multi-domain pretraining (Pálka et al., 2024). In language modeling, incremental improvements may come from router or embedding variants, or from the inclusion of additional domains.

JEET thus occupies a middle ground between monolithic models with maximal parameter sharing and the complete separation of single-task or single-domain pipelines, offering near-state-of-the-art accuracy with robust multi-domain retention and efficient resource utilization (Pálka et al., 2024, Al-Maamari et al., 2024).
