from functools import partial
import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig
from typing import Dict, List
def hidden_state_embedding(hidden_states: torch.Tensor, layers: List[int],
                           use_cls: bool, reduce_mean: bool = True) -> torch.Tensor:
    """
    Extract embeddings from the hidden states of selected transformer layers.

    Parameters
    ----------
    hidden_states
        Per-layer hidden states of the transformer model, indexed by layer
        (``hidden_states[layer]``). Each entry is sliced as ``[:, 0:1, :]``,
        so it must have at least 3 dimensions; presumably
        ``(batch, seq_len, hidden_dim)`` -- confirm against the caller.
    layers
        Indices of the layers to use for the embedding. Must be non-empty.
    use_cls
        If True, keep only the first token position of each selected layer
        (the [CLS] token in BERT-style models). If False, keep every position.
    reduce_mean
        Whether to average the concatenated tensor over dimension 1.

    Returns
    -------
    Tensor with embeddings. Assuming ``(batch, seq_len, hidden_dim)`` inputs:
    ``(batch, hidden_dim)`` if ``reduce_mean`` is True; otherwise the
    concatenation over dimension 1, i.e. ``(batch, len(layers), hidden_dim)``
    when ``use_cls`` is True or ``(batch, len(layers) * seq_len, hidden_dim)``
    when it is False.

    Raises
    ------
    ValueError
        If ``layers`` is empty.
    """
    # Fail early with a clear message instead of an opaque torch.cat error.
    if len(layers) == 0:
        raise ValueError("`layers` must contain at least one layer index.")

    # Slice with 0:1 (not 0) so the token dimension is kept and the per-layer
    # pieces can be concatenated along dim 1.
    selected = [hidden_states[layer][:, 0:1, :] if use_cls else hidden_states[layer]
                for layer in layers]
    stacked = torch.cat(selected, dim=1)

    # NOTE(review): no attention mask is applied, so when use_cls is False any
    # padding positions contribute to the mean.
    return stacked.mean(dim=1) if reduce_mean else stacked