import argparse
import contextlib
import importlib
import json
import logging
import os
import socket
import sys
import time
from distutils.util import strtobool
from functools import partial
from typing import Any, Callable, Dict, List, Tuple
from seldon_core import __version__
from seldon_core import wrapper as seldon_microservice
from seldon_core.flask_utils import (
ANNOTATION_REST_TIMEOUT,
ANNOTATIONS_FILE,
DEFAULT_ANNOTATION_REST_TIMEOUT,
SeldonMicroserviceException,
)
from seldon_core.gunicorn_utils import (
StandaloneApplication,
UserModelApplication,
accesslog,
post_worker_init,
threads,
worker_exit,
)
from seldon_core.metrics import SeldonMetrics
from seldon_core.utils import getenv_as_bool, setup_tracing
# This is related to how multiprocessing is implemented on macOS.
# See https://github.com/SeldonIO/seldon-core/issues/3410 for discussion.
# Environment variable that switches the process backend from the standard
# library `multiprocessing` module to the `multiprocess` package.
USE_MULTIPROCESS_ENV_NAME = "USE_MULTIPROCESS_PACKAGE"
USE_MULTIPROCESS = getenv_as_bool(USE_MULTIPROCESS_ENV_NAME, default=False)
# Whichever module is selected is bound to `mp`, so the rest of the file does
# not need to know which implementation is in use.
if USE_MULTIPROCESS:
    import multiprocess as mp
else:
    import multiprocessing as mp
# Module-level logger; its level is adjusted later by setup_logger().
logger = logging.getLogger(__name__)

# Environment variables that supply defaults for the command-line arguments
# built in parse_args().
PARAMETERS_ENV_NAME = "PREDICTIVE_UNIT_PARAMETERS"
HTTP_SERVICE_PORT_ENV_NAME = "PREDICTIVE_UNIT_HTTP_SERVICE_PORT"
GRPC_SERVICE_PORT_ENV_NAME = "PREDICTIVE_UNIT_GRPC_SERVICE_PORT"
METRICS_SERVICE_PORT_ENV_NAME = "PREDICTIVE_UNIT_METRICS_SERVICE_PORT"

# Read by setup_logger(): whether access-log lines mentioning the metrics
# endpoint are filtered out.
FILTER_METRICS_ACCESS_LOGS_ENV_NAME = "FILTER_METRICS_ACCESS_LOGS"

# Read by setup_logger(): when set, overrides the --log-level argument.
LOG_LEVEL_ENV = "SELDON_LOG_LEVEL"
DEFAULT_LOG_LEVEL = "INFO"

# Fallback ports used when neither the CLI nor the environment provides one.
DEFAULT_GRPC_PORT = 5000
DEFAULT_HTTP_PORT = 9000
DEFAULT_METRICS_PORT = 6000

# Environment defaults for the --debug and --access-log arguments.
DEBUG_ENV = "SELDON_DEBUG"
GUNICORN_ACCESS_LOG_ENV = "GUNICORN_ACCESS_LOG"
def start_servers(
    target1: Callable, target2: Callable, target3: Callable, metrics_target: Callable
) -> None:
    """
    Start servers

    Every falsy target other than ``target1`` is skipped. The remaining ones
    are launched as separate processes, ``target1`` is then run in the calling
    process, and once it returns the launched processes are joined in the
    order they were started.

    Parameters
    ----------
    target1
       Main flask process; called directly in the current process
    target2
       Auxiliary flask process; started as a non-daemon process
    target3
       Optional extra callable; started as a daemon process
    metrics_target
       Optional metrics callable; started as a daemon process
    """
    if USE_MULTIPROCESS:
        backend_message = "Using alternative multiprocessing library"
    else:
        backend_message = "Using standard multiprocessing library"
    logger.info(backend_message)

    # (callable, daemon flag) pairs, listed in launch order.
    background_specs = (
        (target2, False),
        (target3, True),
        (metrics_target, True),
    )

    children = []
    for func, is_daemon in background_specs:
        if not func:
            continue
        child = mp.Process(target=func, daemon=is_daemon)
        child.start()
        children.append(child)

    # The main server occupies the current process until it returns.
    target1()

    for child in children:
        child.join()
def parse_parameters(parameters: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Parse the user object parameters

    Each entry is a mapping with ``name``, ``value`` and ``type`` keys. The
    value is converted according to ``type`` and stored under ``name``.

    Parameters
    ----------
    parameters
       Iterable of parameter descriptions, e.g.
       ``[{"name": "threshold", "value": "0.5", "type": "FLOAT"}]``

    Returns
    -------
       Dict mapping each parameter name to its converted value

    Raises
    ------
    SeldonMicroserviceException
       If ``type`` is missing or not one of INT, FLOAT, DOUBLE, STRING, BOOL,
       or if a non-BOOL value cannot be converted to the requested type.
    ValueError
       If a BOOL value is not a recognised truth string (raised by
       ``strtobool``).
    """
    type_dict = {
        "INT": int,
        "FLOAT": float,
        "DOUBLE": float,
        "STRING": str,
    }
    parsed_parameters = {}
    for param in parameters:
        name = param.get("name")
        value = param.get("value")
        type_ = param.get("type")

        if type_ == "BOOL":
            if isinstance(value, bool):
                # A JSON true/false is already a bool; strtobool only accepts
                # strings and would fail on it.
                parsed_parameters[name] = value
            else:
                parsed_parameters[name] = bool(strtobool(str(value)))
            continue

        try:
            converter = type_dict[type_]
        except (KeyError, TypeError) as err:
            # The message is formatted rather than concatenated so that a
            # missing type (None) is reported instead of raising TypeError.
            raise SeldonMicroserviceException(
                f"Bad model parameter type: {type_}"
                " valid are INT, FLOAT, DOUBLE, STRING, BOOL",
                reason="MICROSERVICE_BAD_PARAMETER",
            ) from err

        try:
            parsed_parameters[name] = converter(value)
        except (ValueError, TypeError) as err:
            # TypeError covers values such as None or a list, which the
            # numeric constructors reject outright.
            raise SeldonMicroserviceException(
                f"Bad model parameter: {name} with value {value}"
                f" can't be parsed as a {type_}",
                reason="MICROSERVICE_BAD_PARAMETER",
            ) from err
    return parsed_parameters
def load_annotations() -> Dict[str, str]:
    """
    Attempt to load annotations

    Reads ``ANNOTATIONS_FILE`` line by line. Each line is split on the first
    ``=``; both sides are stripped of whitespace and the first and last
    character of the value are removed (they are expected to be quotes).
    Lines without ``=`` are logged and skipped. When a key repeats, the last
    occurrence wins.

    Loading is best-effort: if the file is absent an empty dict is returned,
    and any error raised while reading is logged and the annotations
    collected so far are returned.

    Returns
    -------
       Dict mapping annotation keys to their unquoted values
    """
    annotations = {}
    try:
        if os.path.isfile(ANNOTATIONS_FILE):
            with open(ANNOTATIONS_FILE, "r") as ins:
                for line in ins:
                    line = line.rstrip()
                    key, separator, raw_value = line.partition("=")
                    if separator:
                        key = key.strip()
                        # strip quotes at start and end
                        value = raw_value.strip()[1:-1]
                        logger.info("Found annotation %s:%s ", key, value)
                        annotations[key] = value
                    else:
                        logger.info("Bad annotation [%s]", line)
    except Exception:
        # Still best-effort, but KeyboardInterrupt/SystemExit are no longer
        # swallowed and the underlying error is recorded with its traceback.
        logger.error(
            "Failed to open annotations file %s", ANNOTATIONS_FILE, exc_info=True
        )
    return annotations
class MetricsEndpointFilter(logging.Filter):
    """Logging filter that drops records whose message mentions the metrics endpoint."""

    def filter(self, record):
        """Return False for records containing the metrics endpoint path, True otherwise."""
        endpoint = seldon_microservice.METRICS_ENDPOINT
        message = record.getMessage()
        mentions_metrics = endpoint in message
        return not mentions_metrics
def setup_logger(log_level: str, debug_mode: bool) -> logging.Logger:
    """
    Configure log levels and access-log filtering

    Parameters
    ----------
    log_level
       Level name requested on the command line (case-insensitive). The
       ``SELDON_LOG_LEVEL`` environment variable, when set, takes precedence.
    debug_mode
       Whether the service runs in debug mode. Metrics access-log filtering
       defaults to on outside debug mode and off inside it, unless
       ``FILTER_METRICS_ACCESS_LOGS`` says otherwise.

    Returns
    -------
       The module logger, after its level has been set

    Raises
    ------
    ValueError
       If the effective level name is not an integer level defined by
       ``logging``.
    """
    # set up log level; normalise case for both the CLI and the env value
    log_level_raw = os.environ.get(LOG_LEVEL_ENV, log_level).upper()
    log_level_num = getattr(logging, log_level_raw, None)
    if not isinstance(log_level_num, int):
        # Exceptions do not %-interpolate their arguments, so the message is
        # formatted here; it names the effective value, not just the CLI one.
        raise ValueError(f"Invalid log level: {log_level_raw}")

    logger.setLevel(log_level_num)

    # Set right level on access logs
    flask_logger = logging.getLogger("werkzeug")
    flask_logger.setLevel(log_level_num)

    if getenv_as_bool(FILTER_METRICS_ACCESS_LOGS_ENV_NAME, default=not debug_mode):
        flask_logger.addFilter(MetricsEndpointFilter())
        gunicorn_logger = logging.getLogger("gunicorn.access")
        gunicorn_logger.addFilter(MetricsEndpointFilter())

    logger.debug("Log level set to %s:%s", log_level_raw, log_level_num)

    # set log level for the imported microservice type
    seldon_microservice.logger.setLevel(log_level_num)
    logging.getLogger().setLevel(log_level_num)
    for handler in logger.handlers:
        handler.setLevel(log_level_num)
    return logger
def _parse_bool_flag(value: str) -> bool:
    """Convert a command-line truth string ("true", "0", "no", ...) to a bool.

    Raises ValueError for unrecognised input, which argparse reports as a
    usage error.
    """
    return bool(strtobool(value))


def parse_args() -> Tuple[argparse.Namespace, List[str]]:
    """
    Build the command-line parser and parse ``sys.argv``

    Several defaults are taken from environment variables, so the service can
    be configured without command-line flags.

    Returns
    -------
       A ``(namespace, remaining)`` tuple as produced by
       ``ArgumentParser.parse_known_args``; ``remaining`` holds the arguments
       the parser did not recognise.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("interface_name", type=str, help="Name of the user interface.")

    parser.add_argument(
        "--service-type",
        type=str,
        choices=["MODEL", "ROUTER", "TRANSFORMER", "COMBINER", "OUTLIER_DETECTOR"],
        default="MODEL",
    )
    parser.add_argument(
        "--persistence",
        nargs="?",
        default=0,
        const=1,
        type=int,
        help="deprecated argument ",
    )
    parser.add_argument(
        "--parameters", type=str, default=os.environ.get(PARAMETERS_ENV_NAME, "[]")
    )
    parser.add_argument(
        "--log-level",
        type=str,
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        default=DEFAULT_LOG_LEVEL,
        help="Log level of the inference server.",
    )
    parser.add_argument(
        "--debug",
        nargs="?",
        # Not type=bool: bool("False") is True, so an explicit "--debug False"
        # would have switched debug mode on.
        type=_parse_bool_flag,
        default=getenv_as_bool(DEBUG_ENV, default=False),
        const=True,
        help="Enable debug mode.",
    )
    parser.add_argument(
        "--tracing",
        nargs="?",
        default=int(os.environ.get("TRACING", "0")),
        const=1,
        type=int,
    )

    # gunicorn settings, defaults are from
    # http://docs.gunicorn.org/en/stable/settings.html
    parser.add_argument(
        "--workers",
        type=int,
        default=int(os.environ.get("GUNICORN_WORKERS", "1")),
        help="Number of Gunicorn workers for handling requests.",
    )
    parser.add_argument(
        "--threads",
        type=int,
        default=int(os.environ.get("GUNICORN_THREADS", "1")),
        help="Number of threads to run per Gunicorn worker.",
    )
    parser.add_argument(
        "--max-requests",
        type=int,
        default=int(os.environ.get("GUNICORN_MAX_REQUESTS", "0")),
        help="Maximum number of requests gunicorn worker will process before restarting.",
    )
    parser.add_argument(
        "--max-requests-jitter",
        type=int,
        default=int(os.environ.get("GUNICORN_MAX_REQUESTS_JITTER", "0")),
        help="Maximum random jitter to add to max-requests.",
    )
    parser.add_argument(
        "--keepalive",
        type=int,
        default=int(os.environ.get("GUNICORN_KEEPALIVE", "2")),
        help="The number of seconds to wait for requests on a Keep-Alive connection.",
    )
    parser.add_argument(
        "--single-threaded",
        type=int,
        default=int(os.environ.get("FLASK_SINGLE_THREADED", "0")),
        help="Force the Flask app to run single-threaded. Also applies to Gunicorn.",
    )
    parser.add_argument(
        "--http-port",
        type=int,
        default=int(os.environ.get(HTTP_SERVICE_PORT_ENV_NAME, DEFAULT_HTTP_PORT)),
        help="Set http port of seldon service",
    )
    parser.add_argument(
        "--grpc-port",
        type=int,
        default=int(os.environ.get(GRPC_SERVICE_PORT_ENV_NAME, DEFAULT_GRPC_PORT)),
        help="Set grpc port of seldon service",
    )
    parser.add_argument(
        "--metrics-port",
        type=int,
        default=int(
            os.environ.get(METRICS_SERVICE_PORT_ENV_NAME, DEFAULT_METRICS_PORT)
        ),
        help="Set metrics port of seldon service",
    )
    parser.add_argument(
        "--pidfile", type=str, default=None, help="A file path to use for the PID file"
    )
    parser.add_argument(
        "--access-log",
        nargs="?",
        # Same reasoning as --debug: parse the text instead of calling bool().
        type=_parse_bool_flag,
        default=getenv_as_bool(GUNICORN_ACCESS_LOG_ENV, default=False),
        const=True,
        help="Enable gunicorn access log.",
    )
    # The two defaults below are strings; argparse runs string defaults
    # through `type`, so they still end up as ints on the namespace.
    parser.add_argument(
        "--grpc-threads",
        type=int,
        default=os.environ.get("GRPC_THREADS", default="1"),
        help="Number of GRPC threads per worker.",
    )
    parser.add_argument(
        "--grpc-workers",
        type=int,
        default=os.environ.get("GRPC_WORKERS", default="1"),
        help="Number of GPRC workers.",
    )

    return parser.parse_known_args()
def _make_rest_server_debug(
    user_object: Any,
    seldon_metrics: SeldonMetrics,
    args: argparse.Namespace,
    jaeger_extra_tags: List[str],
) -> Callable[[], None]:
    """Build a zero-argument callable that runs the REST app on Flask's own server.

    Args:
      user_object: an instance of user-defined class, inherited from user_model.SeldonComponent.
      seldon_metrics: a SeldonMetrics instance.
      args: parsed args from commandline; tracing, interface_name, http_port
        and single_threaded are read.
      jaeger_extra_tags: tags handed to FlaskTracing when tracing is enabled.

    Returns:
      The callable; nothing is started until it is invoked.
    """

    def server():
        app = seldon_microservice.get_rest_microservice(user_object, seldon_metrics)

        # A missing or unimplemented load() on the user object is tolerated.
        with contextlib.suppress(NotImplementedError, AttributeError):
            user_object.load()

        if args.tracing:
            logger.info("Tracing branch is active")
            # Imported lazily so the package is only needed when tracing is on.
            from flask_opentracing import FlaskTracing

            tracer = setup_tracing(args.interface_name)
            logger.info("Set JAEGER_EXTRA_TAGS %s", jaeger_extra_tags)
            FlaskTracing(tracer, True, app, jaeger_extra_tags)

        # Timeout not supported in flask development server
        run_threaded = not args.single_threaded
        app.run(host="0.0.0.0", port=args.http_port, threaded=run_threaded)

    return server
def _make_rest_server_prod(
    user_object: Any,
    seldon_metrics: SeldonMetrics,
    args: argparse.Namespace,
    jaeger_extra_tags: List[str],
    annotations: Dict[str, str],
) -> Callable[[], None]:
    """Build a zero-argument callable that runs the REST app under Gunicorn.

    Args:
      user_object: an instance of user-defined class, inherited from user_model.SeldonComponent.
      seldon_metrics: a SeldonMetrics instance.
      args: parsed args from commandline.
      jaeger_extra_tags: tags forwarded to UserModelApplication.
      annotations: loaded annotations; the REST timeout annotation, if
        present, overrides the default Gunicorn timeout.

    Returns:
      The callable; nothing is started until it is invoked.
    """

    def server() -> None:
        if ANNOTATION_REST_TIMEOUT in annotations:
            # The annotation is in milliseconds, Gunicorn expects seconds.
            timeout_seconds = int(annotations[ANNOTATION_REST_TIMEOUT]) / 1000
            # Truncate to an int; a result of 0 is replaced by 1.
            rest_timeout = int(timeout_seconds) or 1
        else:
            rest_timeout = DEFAULT_ANNOTATION_REST_TIMEOUT

        options = dict(
            bind=f"0.0.0.0:{args.http_port}",
            accesslog=accesslog(args.access_log),
            loglevel=args.log_level.lower(),
            timeout=rest_timeout,
            threads=threads(args.threads, args.single_threaded),
            workers=args.workers,
            max_requests=args.max_requests,
            max_requests_jitter=args.max_requests_jitter,
            post_worker_init=post_worker_init,
            worker_exit=partial(worker_exit, seldon_metrics=seldon_metrics),
            keepalive=args.keepalive,
        )
        logger.info(f"Gunicorn Config: {options}")

        # Added after the config is logged, as before.
        if args.pidfile is not None:
            options["pidfile"] = args.pidfile

        app = seldon_microservice.get_rest_microservice(user_object, seldon_metrics)
        gunicorn_app = UserModelApplication(
            app,
            user_object,
            args.tracing,
            jaeger_extra_tags,
            args.interface_name,
            options=options,
        )
        gunicorn_app.run()

    return server
def _wait_forever(server):
    """Sleep in one-hour intervals until interrupted, then call ``server.stop(None)``."""
    one_hour_seconds = 60 * 60
    try:
        while True:
            time.sleep(one_hour_seconds)
    except KeyboardInterrupt:
        server.stop(None)
def _run_grpc_server(
    user_object: Any,
    seldon_metrics: SeldonMetrics,
    args: argparse.Namespace,
    annotations: Dict[str, str],
    bind_address: str,
):
    """Create a gRPC server, bind it to ``bind_address`` and block until interrupted.

    Used as the target of each gRPC worker process.

    Args:
      user_object: the user component served over gRPC.
      seldon_metrics: a SeldonMetrics instance.
      args: parsed args from commandline; grpc_threads, tracing and
        interface_name are read.
      annotations: loaded annotations, forwarded to the server factory.
      bind_address: "host:port" string passed to add_insecure_port.
    """
    logger.info(f"Starting new GRPC server with {args.grpc_threads} threads.")

    interceptor = None
    if args.tracing:
        # Imported lazily so the package is only needed when tracing is on.
        from grpc_opentracing import open_tracing_server_interceptor

        logger.info("Adding tracer")
        tracer = setup_tracing(args.interface_name)
        interceptor = open_tracing_server_interceptor(tracer)

    server = seldon_microservice.get_grpc_server(
        user_object,
        seldon_metrics,
        annotations=annotations,
        trace_interceptor=interceptor,
        num_threads=args.grpc_threads,
    )

    # A missing or unimplemented load() on the user object is tolerated.
    with contextlib.suppress(NotImplementedError, AttributeError):
        user_object.load()

    server.add_insecure_port(bind_address)
    server.start()
    _wait_forever(server)
def _make_grpc_server(
    user_object: Any,
    seldon_metrics: SeldonMetrics,
    args: argparse.Namespace,
    annotations: Dict[str, str],
) -> Callable[[], None]:
    """Build a zero-argument callable that runs the gRPC worker processes.

    Args:
      user_object: the user component served over gRPC.
      seldon_metrics: a SeldonMetrics instance.
      args: parsed args from commandline; grpc_port and grpc_workers are read.
      annotations: loaded annotations, forwarded to every worker.

    Returns:
      The callable. When invoked it reserves the gRPC port, starts
      ``args.grpc_workers`` processes and waits for all of them.
    """

    def server() -> None:
        # The reservation is held for as long as the workers are running.
        with _reserve_grpc_port(args.grpc_port) as bind_port:
            bind_address = f"0.0.0.0:{bind_port}"
            logger.info(
                "GRPC Server Binding to %s with %d processes.",
                bind_address,
                args.grpc_workers,
            )
            sys.stdout.flush()

            worker_args = (
                user_object,
                seldon_metrics,
                args,
                annotations,
                bind_address,
            )
            spawned = []
            for _ in range(args.grpc_workers):
                # NOTE: It is imperative that the worker subprocesses be forked before
                # any gRPC servers start up. See
                # https://github.com/grpc/grpc/issues/16001 for more details.
                proc = mp.Process(target=_run_grpc_server, args=worker_args)
                proc.start()
                spawned.append(proc)

            for proc in spawned:
                proc.join()

    return server
def _make_rest_metrics_server(
    seldon_metrics: SeldonMetrics,
    args: argparse.Namespace,
) -> Callable[[], None]:
    """Build a zero-argument callable that serves the metrics app.

    Args:
      seldon_metrics: a SeldonMetrics instance.
      args: parsed args from commandline; debug selects Flask's own server,
        otherwise the app is run through StandaloneApplication.

    Returns:
      The callable; nothing is started until it is invoked.
    """

    def server() -> None:
        app = seldon_microservice.get_metrics_microservice(seldon_metrics)

        if args.debug:
            app.run(host="0.0.0.0", port=args.metrics_port)
            return

        options = dict(
            bind=f"0.0.0.0:{args.metrics_port}",
            accesslog=accesslog(args.access_log),
            loglevel=args.log_level.lower(),
            timeout=5000,
            max_requests=args.max_requests,
            max_requests_jitter=args.max_requests_jitter,
            post_worker_init=post_worker_init,
            keepalive=args.keepalive,
        )
        if args.pidfile is not None:
            options["pidfile"] = args.pidfile
        StandaloneApplication(app, options=options).run()

    return server
@contextlib.contextmanager
def _reserve_grpc_port(grpc_port: int):
    """Find and reserve a port for all subprocesses to use.

    Opens a TCP socket with SO_REUSEPORT set, binds it to ``grpc_port`` on all
    interfaces and keeps it open for the duration of the ``with`` block.

    Args:
      grpc_port: port to bind.

    Yields:
      The port number reported by the bound socket.

    Raises:
      RuntimeError: if SO_REUSEPORT reads back as disabled after being set.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # The try starts immediately after creation so the socket is closed even
    # when configuring or binding it fails.
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        # Only zero means "off". Some platforms (BSD-derived stacks, for
        # example) may report an enabled boolean option as a non-zero value
        # other than 1, so comparing against 1 would reject a working socket.
        if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0:
            raise RuntimeError("Failed to set SO_REUSEPORT.")
        sock.bind(("", grpc_port))
        yield sock.getsockname()[1]
    finally:
        sock.close()
def main():
    """Command-line entry point.

    Parses arguments, instantiates the user class named by ``interface_name``
    with the parsed parameters, builds the REST, gRPC and metrics server
    callables (plus the user object's ``custom_service`` if it is callable)
    and hands them to ``start_servers``. Exits with status -1 when
    unrecognised arguments are present.
    """
    log_format = (
        "%(asctime)s - %(name)s:%(funcName)s:%(lineno)s - %(levelname)s: %(message)s"
    )
    logging.basicConfig(level=logging.INFO, format=log_format)
    logger.info("Starting microservice.py:main")
    logger.info(f"Seldon Core version: {__version__}")

    # Make modules in the working directory importable.
    sys.path.append(os.getcwd())

    args, remaining = parse_args()
    if remaining:
        logger.error(
            f"Unknown args {remaining}. Note since 1.5.0 this CLI does not take API type (REST, GRPC)"
        )
        sys.exit(-1)

    parameters = parse_parameters(json.loads(args.parameters))

    setup_logger(args.log_level, args.debug)

    # set flask trace jaeger extra tags
    raw_tags = os.environ.get("JAEGER_EXTRA_TAGS", "")
    stripped_tags = [tag.strip() for tag in raw_tags.split(",")]
    jaeger_extra_tags = [tag for tag in stripped_tags if tag != ""]
    logger.info("Parse JAEGER_EXTRA_TAGS %s", jaeger_extra_tags)

    annotations = load_annotations()
    logger.info("Annotations: %s", annotations)

    # "pkg.module.Class" -> module "pkg.module", attribute "Class";
    # a bare "Name" -> module "Name", attribute "Name".
    parts = args.interface_name.rsplit(".", 1)
    if len(parts) == 2:
        logger.info("Importing submodule %s", parts)
        module_name, class_name = parts
    else:
        logger.info("Importing %s", args.interface_name)
        module_name = class_name = args.interface_name
    interface_file = importlib.import_module(module_name)
    user_class = getattr(interface_file, class_name)

    if args.persistence:
        logger.error("Persistence: ignored, persistence is deprecated")
    user_object = user_class(**parameters)

    seldon_metrics = SeldonMetrics(worker_id_func=os.getpid)
    # TODO why 2 ways to create metrics server
    # (alternative: worker_id_func=lambda: threading.current_thread().name)

    if args.debug:
        # Start Flask debug server
        logger.info(
            "REST microservice running on port %i single-threaded=%s",
            args.http_port,
            args.single_threaded,
        )
        server_rest_func = _make_rest_server_debug(
            user_object, seldon_metrics, args, jaeger_extra_tags=jaeger_extra_tags
        )
    else:
        # Start production server
        logger.info("REST gunicorn microservice running on port %i", args.http_port)
        server_rest_func = _make_rest_server_prod(
            user_object,
            seldon_metrics,
            args,
            jaeger_extra_tags=jaeger_extra_tags,
            annotations=annotations,
        )

    # gRPC is only served when at least one worker is requested.
    if args.grpc_workers > 0:
        server_grpc_func = _make_grpc_server(
            user_object, seldon_metrics, args, annotations=annotations
        )
    else:
        server_grpc_func = None

    logger.info("REST metrics microservice running on port %i", args.metrics_port)
    server_metrics_func = _make_rest_metrics_server(seldon_metrics, args)

    custom_service = getattr(user_object, "custom_service", None)
    server_custom_func = custom_service if callable(custom_service) else None

    logger.info("Starting servers")
    start_servers(
        server_rest_func,
        server_grpc_func,
        server_custom_func,
        server_metrics_func,
    )


if __name__ == "__main__":
    main()