from alibi_detect.utils.missing_optional_dependency import import_optional
from typing import Union
from typing_extensions import Literal, Protocol, runtime_checkable


# Use Protocols instead of base classes for the backend associated objects. This is a bit more flexible and allows us
# to avoid the torch/tensorflow imports in the base class.
#
# NOTE(review): the original protocol class bodies appear to have been lost in a docs extraction — `Protocol` and
# `runtime_checkable` are imported above but were never used, while `TransformProtocol`/`FittedTransformProtocol`
# were referenced below without being defined (a NameError at import time). Minimal reconstructions are provided
# here; confirm the exact required methods against the upstream source.

@runtime_checkable
class TransformProtocol(Protocol):
    """Structural type for transform objects that need no fitting.

    Kept minimal: `runtime_checkable` isinstance checks only test for the
    presence of the listed methods.
    """
    def transform(self, x):
        ...


@runtime_checkable
class FittedTransformProtocol(TransformProtocol, Protocol):
    """Structural type for transform objects that must be fitted before use."""
    def fit(self, x_ref):
        ...


# Either kind of transform — fitted or not — is acceptable where this alias is used.
TransformProtocolType = Union[TransformProtocol, FittedTransformProtocol]
NormalizerLiterals = Literal['PValNormalizer', 'ShiftAndScaleNormalizer']
AggregatorLiterals = Literal['TopKAggregator', 'AverageAggregator',
                             'MaxAggregator', 'MinAggregator']

# The torch backend is optional; import_optional returns placeholder objects that raise a
# helpful error on use when 'alibi_detect.od.pytorch.ensemble' cannot be imported.
PValNormalizer, ShiftAndScaleNormalizer, TopKAggregator, AverageAggregator, \
    MaxAggregator, MinAggregator = import_optional(
        'alibi_detect.od.pytorch.ensemble',
        ['PValNormalizer', 'ShiftAndScaleNormalizer', 'TopKAggregator',
         'AverageAggregator', 'MaxAggregator', 'MinAggregator']
    )
def get_normalizer(normalizer: Union[TransformProtocolType, NormalizerLiterals]) -> TransformProtocol:
    """Resolve a normalizer from its string name, or pass an instance through unchanged.

    Parameters
    ----------
    normalizer
        Either an already-constructed transform object, or one of the literal
        names 'PValNormalizer' / 'ShiftAndScaleNormalizer'.

    Returns
    -------
    The normalizer instance (newly constructed when a name was given).

    Raises
    ------
    NotImplementedError
        If `normalizer` is a string that does not name a known normalizer.
    """
    if isinstance(normalizer, str):
        try:
            # Use indexing (not .get) so an unknown name raises KeyError, which is
            # translated below. The previous `.get(normalizer)()` returned None for
            # unknown names and raised an uncaught TypeError instead, so the
            # NotImplementedError path was unreachable.
            return {
                'PValNormalizer': PValNormalizer,
                'ShiftAndScaleNormalizer': ShiftAndScaleNormalizer,
            }[normalizer]()
        except KeyError:
            raise NotImplementedError(f'Normalizer {normalizer} not implemented.')
    return normalizer
def get_aggregator(aggregator: Union[TransformProtocol, AggregatorLiterals]) -> TransformProtocol:
    """Resolve an aggregator from its string name, or pass an instance through unchanged.

    Parameters
    ----------
    aggregator
        Either an already-constructed transform object, or one of the literal names
        'TopKAggregator' / 'AverageAggregator' / 'MaxAggregator' / 'MinAggregator'.

    Returns
    -------
    The aggregator instance (newly constructed when a name was given).

    Raises
    ------
    NotImplementedError
        If `aggregator` is a string that does not name a known aggregator.
    """
    if isinstance(aggregator, str):
        try:
            # Use indexing (not .get) so an unknown name raises KeyError, which is
            # translated below. The previous `.get(aggregator)()` returned None for
            # unknown names and raised an uncaught TypeError instead, so the
            # NotImplementedError path was unreachable.
            return {
                'TopKAggregator': TopKAggregator,
                'AverageAggregator': AverageAggregator,
                'MaxAggregator': MaxAggregator,
                'MinAggregator': MinAggregator,
            }[aggregator]()
        except KeyError:
            raise NotImplementedError(f'Aggregator {aggregator} not implemented.')
    return aggregator