Source code for alibi_detect.cd.pytorch.utils

from torch import nn
from typing import Callable


def activate_train_mode_for_dropout_layers(model: Callable) -> Callable:
    """Put ``model`` in eval mode except for its dropout layers, which are set to train mode.

    Dropout therefore stays stochastic at inference time while every other
    submodule behaves as it does in eval mode. The model is modified in place
    and the same object is returned.

    Parameters
    ----------
    model
        A PyTorch ``nn.Module`` (annotated as ``Callable``). It must provide
        ``eval()`` and ``modules()``.

    Returns
    -------
    The same model object, with every dropout layer in train mode.

    Raises
    ------
    ValueError
        If the model contains no dropout layers.
    """
    # ``Dropout1d/2d/3d``, ``AlphaDropout`` and ``FeatureAlphaDropout`` are NOT
    # subclasses of ``nn.Dropout``, so checking ``nn.Dropout`` alone would leave
    # them deterministic. ``getattr(..., None)`` tolerates torch versions that
    # lack some of these classes (e.g. ``Dropout1d`` on older releases).
    candidate_names = (
        "Dropout",
        "Dropout1d",
        "Dropout2d",
        "Dropout3d",
        "AlphaDropout",
        "FeatureAlphaDropout",
    )
    dropout_types = tuple(
        cls for cls in (getattr(nn, name, None) for name in candidate_names) if cls is not None
    )

    # Switch every submodule to eval mode first, then re-enable train mode on dropout only.
    model.eval()  # type: ignore
    n_dropout_layers = 0
    for module in model.modules():  # type: ignore
        if isinstance(module, dropout_types):
            module.train()
            n_dropout_layers += 1
    if n_dropout_layers == 0:
        raise ValueError("No dropout layers identified.")
    return model