diff --git a/tabpy/models/__init__.py b/tabpy/models/__init__.py deleted file mode 100755 index e69de29b..00000000 diff --git a/tabpy/models/deploy_models.py b/tabpy/models/deploy_models.py deleted file mode 100644 index a41171a7..00000000 --- a/tabpy/models/deploy_models.py +++ /dev/null @@ -1,35 +0,0 @@ -import os -from pathlib import Path -import pip -import platform -import subprocess -import sys -from tabpy.models.utils import setup_utils - - -def main(): - # Determine if we run python or python3 - if platform.system() == "Windows": - py = "python" - else: - py = "python3" - - if len(sys.argv) > 1: - config_file_path = sys.argv[1] - else: - config_file_path = setup_utils.get_default_config_file_path() - print(f"Using config file at {config_file_path}") - port, auth_on, prefix = setup_utils.parse_config(config_file_path) - if auth_on: - auth_args = setup_utils.get_creds() - else: - auth_args = [] - - directory = str(Path(__file__).resolve().parent / "scripts") - # Deploy each model in the scripts directory - for filename in os.listdir(directory): - subprocess.run([py, f"{directory}/{filename}", config_file_path] + auth_args) - - -if __name__ == "__main__": - main() diff --git a/tabpy/models/scripts/ANOVA.py b/tabpy/models/scripts/ANOVA.py deleted file mode 100644 index 4cf90c2b..00000000 --- a/tabpy/models/scripts/ANOVA.py +++ /dev/null @@ -1,22 +0,0 @@ -import scipy.stats as stats -from tabpy.models.utils import setup_utils - - -def anova(_arg1, _arg2, *_argN): - """ - ANOVA is a statistical hypothesis test that is used to compare - two or more group means for equality.For more information on - the function and how to use it please refer to tabpy-tools.md - """ - - cols = [_arg1, _arg2] + list(_argN) - for col in cols: - if not isinstance(col[0], (int, float)): - print("values must be numeric") - raise ValueError - _, p_value = stats.f_oneway(_arg1, _arg2, *_argN) - return p_value - - -if __name__ == "__main__": - setup_utils.deploy_model("anova", anova, 
"Returns the p-value form an ANOVA test") diff --git a/tabpy/models/scripts/PCA.py b/tabpy/models/scripts/PCA.py deleted file mode 100644 index df23632c..00000000 --- a/tabpy/models/scripts/PCA.py +++ /dev/null @@ -1,60 +0,0 @@ -import pandas as pd -from numpy import array -from sklearn.decomposition import PCA as sklearnPCA -from sklearn.preprocessing import StandardScaler -from sklearn.preprocessing import LabelEncoder -from sklearn.preprocessing import OneHotEncoder -from tabpy.models.utils import setup_utils - - -def PCA(component, _arg1, _arg2, *_argN): - """ - Principal Component Analysis is a technique that extracts the key - distinct components from a high dimensional space whie attempting - to capture as much of the variance as possible. For more information - on the function and how to use it please refer to tabpy-tools.md - """ - cols = [_arg1, _arg2] + list(_argN) - encodedCols = [] - labelEncoder = LabelEncoder() - oneHotEncoder = OneHotEncoder(categories="auto", sparse=False) - - for col in cols: - if isinstance(col[0], (int, float)): - encodedCols.append(col) - elif type(col[0]) is bool: - intCol = array(col) - encodedCols.append(intCol.astype(int)) - else: - if len(set(col)) > 25: - print( - "ERROR: Non-numeric arguments cannot have more than " - "25 unique values" - ) - raise ValueError - integerEncoded = labelEncoder.fit_transform(array(col)) - integerEncoded = integerEncoded.reshape(len(col), 1) - oneHotEncoded = oneHotEncoder.fit_transform(integerEncoded) - transformedMatrix = oneHotEncoded.transpose() - encodedCols += list(transformedMatrix) - - dataDict = {} - for i in range(len(encodedCols)): - dataDict[f"col{1 + i}"] = list(encodedCols[i]) - - if component <= 0 or component > len(dataDict): - print("ERROR: Component specified must be >= 0 and " "<= number of arguments") - raise ValueError - - df = pd.DataFrame(data=dataDict, dtype=float) - scale = StandardScaler() - scaledData = scale.fit_transform(df) - - pca = sklearnPCA() - pcaComponents 
= pca.fit_transform(scaledData) - - return pcaComponents[:, component - 1].tolist() - - -if __name__ == "__main__": - setup_utils.deploy_model("PCA", PCA, "Returns the specified principal component") diff --git a/tabpy/models/scripts/SentimentAnalysis.py b/tabpy/models/scripts/SentimentAnalysis.py deleted file mode 100644 index ed4e0c7e..00000000 --- a/tabpy/models/scripts/SentimentAnalysis.py +++ /dev/null @@ -1,52 +0,0 @@ -from textblob import TextBlob -import nltk -from nltk.sentiment.vader import SentimentIntensityAnalyzer -from tabpy.models.utils import setup_utils - - -import ssl - -_ctx = ssl._create_unverified_context -ssl._create_default_https_context = _ctx - - -nltk.download("vader_lexicon") -nltk.download("punkt") - - -def SentimentAnalysis(_arg1, library="nltk"): - """ - Sentiment Analysis is a procedure that assigns a score from -1 to 1 - for a piece of text with -1 being negative and 1 being positive. For - more information on the function and how to use it please refer to - tabpy-tools.md - """ - if not (isinstance(_arg1[0], str)): - raise TypeError - - supportedLibraries = {"nltk", "textblob"} - - library = library.lower() - if library not in supportedLibraries: - raise ValueError - - scores = [] - if library == "nltk": - sid = SentimentIntensityAnalyzer() - for text in _arg1: - sentimentResults = sid.polarity_scores(text) - score = sentimentResults["compound"] - scores.append(score) - elif library == "textblob": - for text in _arg1: - currScore = TextBlob(text) - scores.append(currScore.sentiment.polarity) - return scores - - -if __name__ == "__main__": - setup_utils.deploy_model( - "Sentiment Analysis", - SentimentAnalysis, - "Returns a sentiment score between -1 and 1 for " "a given string", - ) diff --git a/tabpy/models/scripts/__init__.py b/tabpy/models/scripts/__init__.py deleted file mode 100755 index e69de29b..00000000 diff --git a/tabpy/models/scripts/tTest.py b/tabpy/models/scripts/tTest.py deleted file mode 100644 index 
433a4750..00000000 --- a/tabpy/models/scripts/tTest.py +++ /dev/null @@ -1,39 +0,0 @@ -from scipy import stats -from tabpy.models.utils import setup_utils - - -def ttest(_arg1, _arg2): - """ - T-Test is a statistical hypothesis test that is used to compare - two sample means or a sample’s mean against a known population mean. - For more information on the function and how to use it please refer - to tabpy-tools.md - """ - # one sample test with mean - if len(_arg2) == 1: - test_stat, p_value = stats.ttest_1samp(_arg1, _arg2) - return p_value - # two sample t-test where _arg1 is numeric and _arg2 is a binary factor - elif len(set(_arg2)) == 2: - # each sample in _arg1 needs to have a corresponding classification - # in _arg2 - if not (len(_arg1) == len(_arg2)): - raise ValueError - class1, class2 = set(_arg2) - sample1 = [] - sample2 = [] - for i in range(len(_arg1)): - if _arg2[i] == class1: - sample1.append(_arg1[i]) - else: - sample2.append(_arg1[i]) - test_stat, p_value = stats.ttest_ind(sample1, sample2, equal_var=False) - return p_value - # arg1 is a sample and arg2 is a sample - else: - test_stat, p_value = stats.ttest_ind(_arg1, _arg2, equal_var=False) - return p_value - - -if __name__ == "__main__": - setup_utils.deploy_model("ttest", ttest, "Returns the p-value form a t-test") diff --git a/tabpy/models/utils/__init__.py b/tabpy/models/utils/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tabpy/models/utils/setup_utils.py b/tabpy/models/utils/setup_utils.py deleted file mode 100644 index cb69c8e5..00000000 --- a/tabpy/models/utils/setup_utils.py +++ /dev/null @@ -1,65 +0,0 @@ -import configparser -import getpass -import os -import sys -from tabpy.tabpy_tools.client import Client - - -def get_default_config_file_path(): - import tabpy - - pkg_path = os.path.dirname(tabpy.__file__) - config_file_path = os.path.join(pkg_path, "tabpy_server", "common", "default.conf") - return config_file_path - - -def parse_config(config_file_path): 
- config = configparser.ConfigParser() - config.read(config_file_path) - tabpy_config = config["TabPy"] - - port = 9004 - if "TABPY_PORT" in tabpy_config: - port = tabpy_config["TABPY_PORT"] - - auth_on = "TABPY_PWD_FILE" in tabpy_config - ssl_on = ( - "TABPY_TRANSFER_PROTOCOL" in tabpy_config - and "TABPY_CERTIFICATE_FILE" in tabpy_config - and "TABPY_KEY_FILE" in tabpy_config - ) - prefix = "https" if ssl_on else "http" - return port, auth_on, prefix - - -def get_creds(): - if sys.stdin.isatty(): - user = input("Username: ") - passwd = getpass.getpass("Password: ") - else: - user = sys.stdin.readline().rstrip() - passwd = sys.stdin.readline().rstrip() - return [user, passwd] - - -def deploy_model(funcName, func, funcDescription): - # running from deploy_models.py - if len(sys.argv) > 1: - config_file_path = sys.argv[1] - else: - config_file_path = get_default_config_file_path() - port, auth_on, prefix = parse_config(config_file_path) - - connection = Client(f"{prefix}://localhost:{port}/") - - if auth_on: - # credentials are passed in from setup.py - if len(sys.argv) == 4: - user, passwd = sys.argv[2], sys.argv[3] - # running Sentiment Analysis independently - else: - user, passwd = get_creds() - connection.set_credentials(user, passwd) - - connection.deploy(funcName, func, funcDescription, override=True) - print(f"Successfully deployed {funcName}") diff --git a/tabpy/tabpy_server/app/app.py b/tabpy/tabpy_server/app/app.py index 0ee807ad..cfbfaa5a 100644 --- a/tabpy/tabpy_server/app/app.py +++ b/tabpy/tabpy_server/app/app.py @@ -4,6 +4,7 @@ from logging import config import multiprocessing import os +import pkg_resources import shutil import signal import sys @@ -14,8 +15,6 @@ from tabpy.tabpy_server.app.util import parse_pwd_file from tabpy.tabpy_server.management.state import TabPyState from tabpy.tabpy_server.management.util import _get_state_from_file -from tabpy.tabpy_server.psws.callbacks import init_model_evaluator, init_ps_server -from 
tabpy.tabpy_server.psws.python_service import PythonService, PythonServiceHandler from tabpy.tabpy_server.handlers import ( EndpointHandler, EndpointsHandler, @@ -23,7 +22,6 @@ QueryPlaneHandler, ServiceInfoHandler, StatusHandler, - UploadDestinationHandler, ) import tornado @@ -60,6 +58,7 @@ class TabPyApp: tabpy_state = None python_service = None credentials = {} + models = {} def __init__(self, config_file=None): if config_file is None: @@ -74,6 +73,7 @@ def __init__(self, config_file=None): logging.basicConfig(level=logging.DEBUG) self._parse_config(config_file) + self._load_models() def run(self): application = self._create_tornado_web_app() @@ -82,8 +82,6 @@ def run(self): ) logger.info(f"Setting max request size to {max_request_size} bytes") - init_model_evaluator(self.settings, self.tabpy_state, self.python_service) - protocol = self.settings[SettingsParameters.TransferProtocol] ssl_options = None if protocol == "https": @@ -122,12 +120,6 @@ def try_exit(self): tornado.ioloop.IOLoop.instance().stop() logger.info("Shutting down TabPy...") - logger.info("Initializing TabPy...") - tornado.ioloop.IOLoop.instance().run_sync( - lambda: init_ps_server(self.settings, self.tabpy_state) - ) - logger.info("Done initializing TabPy.") - executor = concurrent.futures.ThreadPoolExecutor( max_workers=multiprocessing.cpu_count() ) @@ -157,11 +149,6 @@ def try_exit(self): EvaluationPlaneHandler, dict(executor=executor, app=self), ), - ( - self.subdirectory + r"/configurations/endpoint_upload_destination", - UploadDestinationHandler, - dict(app=self), - ), ( self.subdirectory + r"/(.*)", tornado.web.StaticFileHandler, @@ -306,7 +293,6 @@ def _parse_config(self, config_file): ) state_config, self.tabpy_state = self._build_tabpy_state() - self.python_service = PythonServiceHandler(PythonService()) self.settings["compress_response"] = True self.settings[SettingsParameters.StaticPath] = os.path.abspath( self.settings[SettingsParameters.StaticPath] @@ -435,4 +421,11 @@ def 
_build_tabpy_state(self): logger.info(f"Loading state from state file {state_file_path}") tabpy_state = _get_state_from_file(state_file_dir) - return tabpy_state, TabPyState(config=tabpy_state, settings=self.settings) + return tabpy_state, TabPyState(config=tabpy_state, settings=self.settings, models=self.models) + + def _load_models(self): + logger.info("Loading models...") + for pkg in pkg_resources.iter_entry_points('tabpy_models'): + logger.info(f"Loading model '{pkg.name}'...") + self.models[pkg.name] = pkg.load() + logger.info(f"Loaded {len(self.models)} models") diff --git a/tabpy/tabpy_server/common/endpoint_file_mgr.py b/tabpy/tabpy_server/common/endpoint_file_mgr.py deleted file mode 100644 index 6b7fed00..00000000 --- a/tabpy/tabpy_server/common/endpoint_file_mgr.py +++ /dev/null @@ -1,94 +0,0 @@ -""" -This module provides functionality required for managing endpoint objects in -TabPy. It provides a way to download endpoint files from remote -and then properly cleanup local the endpoint files on update/remove of endpoint -objects. - -The local temporary files for TabPy will by default located at - /tmp/query_objects - -""" -import logging -import os -import shutil -from re import compile as _compile - - -_name_checker = _compile(r"^[a-zA-Z0-9-_\s]+$") - - -def _check_endpoint_name(name, logger=logging.getLogger(__name__)): - """Checks that the endpoint name is valid by comparing it with an RE and - checking that it is not reserved.""" - if not isinstance(name, str): - msg = "Endpoint name must be a string" - logger.log(logging.CRITICAL, msg) - raise TypeError(msg) - - if name == "": - msg = "Endpoint name cannot be empty" - logger.log(logging.CRITICAL, msg) - raise ValueError(msg) - - if not _name_checker.match(name): - msg = ( - "Endpoint name can only contain: a-z, A-Z, 0-9," - " underscore, hyphens and spaces." 
- ) - logger.log(logging.CRITICAL, msg) - raise ValueError(msg) - - -def grab_files(directory): - """ - Generator that returns all files in a directory. - """ - if not os.path.isdir(directory): - return - else: - for name in os.listdir(directory): - full_path = os.path.join(directory, name) - if os.path.isdir(full_path): - for entry in grab_files(full_path): - yield entry - elif os.path.isfile(full_path): - yield full_path - - -def cleanup_endpoint_files( - name, query_path, logger=logging.getLogger(__name__), retain_versions=None -): - """ - Cleanup the disk space a certain endpiont uses. - - Parameters - ---------- - name : str - The endpoint name - - retain_version : int, optional - If given, then all files for this endpoint are removed except the - folder for the given version, otherwise, all files for that endpoint - are removed. - """ - _check_endpoint_name(name, logger=logger) - local_dir = os.path.join(query_path, name) - - # nothing to clean, this is true for state file path where we load - # Query Object directly from the state path instead of downloading - # to temporary location - if not os.path.exists(local_dir): - return - - if not retain_versions: - shutil.rmtree(local_dir) - else: - retain_folders = [ - os.path.join(local_dir, str(version)) for version in retain_versions - ] - logger.log(logging.INFO, f"Retain folders: {retain_folders}") - - for file_or_dir in os.listdir(local_dir): - candidate_dir = os.path.join(local_dir, file_or_dir) - if os.path.isdir(candidate_dir) and (candidate_dir not in retain_folders): - shutil.rmtree(candidate_dir) diff --git a/tabpy/tabpy_server/common/messages.py b/tabpy/tabpy_server/common/messages.py deleted file mode 100644 index ad684319..00000000 --- a/tabpy/tabpy_server/common/messages.py +++ /dev/null @@ -1,172 +0,0 @@ -import abc -from abc import ABCMeta -from collections import namedtuple -import json - - -class Msg: - """ - An abstract base class for all messages used for communicating between - the 
WebServices. - - The minimal functionality is the ability to instantiate a Msg from JSON - and to write a Msg instance to JSON. - - We use namedtuples because they are lightweight and immutable. The splat - operator (*) that we inherit from namedtuple is also convenient. We empty - __slots__ to avoid unnecessary overhead. - """ - - __metaclass__ = ABCMeta - - @abc.abstractmethod - def for_json(self): - d = self._asdict() - type_str = self.__class__.__name__ - d.update({"type": type_str}) - return d - - @abc.abstractmethod - def to_json(self): - return json.dumps(self.for_json()) - - @staticmethod - def from_json(str): - d = json.loads(str) - type_str = d["type"] - del d["type"] - return eval(type_str)(**d) - - -class LoadSuccessful( - namedtuple( - "LoadSuccessful", ["uri", "path", "version", "is_update", "endpoint_type"] - ), - Msg, -): - __slots__ = () - - -class LoadFailed(namedtuple("LoadFailed", ["uri", "version", "error_msg"]), Msg): - __slots__ = () - - -class LoadInProgress( - namedtuple( - "LoadInProgress", ["uri", "path", "version", "is_update", "endpoint_type"] - ), - Msg, -): - __slots__ = () - - -class Query(namedtuple("Query", ["uri", "params"]), Msg): - __slots__ = () - - -class QuerySuccessful( - namedtuple("QuerySuccessful", ["uri", "version", "response"]), Msg -): - __slots__ = () - - -class LoadObject( - namedtuple("LoadObject", ["uri", "url", "version", "is_update", "endpoint_type"]), - Msg, -): - __slots__ = () - - -class DeleteObjects(namedtuple("DeleteObjects", ["uris"]), Msg): - __slots__ = () - - -# Used for testing to flush out objects -class FlushObjects(namedtuple("FlushObjects", []), Msg): - __slots__ = () - - -class ObjectsDeleted(namedtuple("ObjectsDeleted", ["uris"]), Msg): - __slots__ = () - - -class ObjectsFlushed(namedtuple("ObjectsFlushed", ["n_before", "n_after"]), Msg): - __slots__ = () - - -class CountObjects(namedtuple("CountObjects", []), Msg): - __slots__ = () - - -class ObjectCount(namedtuple("ObjectCount", ["count"]), 
Msg): - __slots__ = () - - -class ListObjects(namedtuple("ListObjects", []), Msg): - __slots__ = () - - -class ObjectList(namedtuple("ObjectList", ["objects"]), Msg): - __slots__ = () - - -class UnknownURI(namedtuple("UnknownURI", ["uri"]), Msg): - __slots__ = () - - -class UnknownMessage(namedtuple("UnknownMessage", ["msg"]), Msg): - __slots__ = () - - -class DownloadSkipped( - namedtuple("DownloadSkipped", ["uri", "version", "msg", "host"]), Msg -): - __slots__ = () - - -class QueryFailed(namedtuple("QueryFailed", ["uri", "error"]), Msg): - __slots__ = () - - -class QueryError(namedtuple("QueryError", ["uri", "error"]), Msg): - __slots__ = () - - -class CheckHealth(namedtuple("CheckHealth", []), Msg): - __slots__ = () - - -class Healthy(namedtuple("Healthy", []), Msg): - __slots__ = () - - -class Unhealthy(namedtuple("Unhealthy", []), Msg): - __slots__ = () - - -class Ping(namedtuple("Ping", ["id"]), Msg): - __slots__ = () - - -class Pong(namedtuple("Pong", ["id"]), Msg): - __slots__ = () - - -class Listening(namedtuple("Listening", []), Msg): - __slots__ = () - - -class EngineFailure(namedtuple("EngineFailure", ["error"]), Msg): - __slots__ = () - - -class FlushLogs(namedtuple("FlushLogs", []), Msg): - __slots__ = () - - -class LogsFlushed(namedtuple("LogsFlushed", []), Msg): - __slots__ = () - - -class ServiceError(namedtuple("ServiceError", ["error"]), Msg): - __slots__ = () diff --git a/tabpy/tabpy_server/common/util.py b/tabpy/tabpy_server/common/util.py deleted file mode 100644 index c731450a..00000000 --- a/tabpy/tabpy_server/common/util.py +++ /dev/null @@ -1,3 +0,0 @@ -def format_exception(e, context): - err_msg = f"{e.__class__.__name__} : {str(e)}" - return err_msg diff --git a/tabpy/tabpy_server/handlers/__init__.py b/tabpy/tabpy_server/handlers/__init__.py index 0c00cde6..5978ae11 100644 --- a/tabpy/tabpy_server/handlers/__init__.py +++ b/tabpy/tabpy_server/handlers/__init__.py @@ -1,6 +1,5 @@ from tabpy.tabpy_server.handlers.base_handler import 
BaseHandler from tabpy.tabpy_server.handlers.main_handler import MainHandler -from tabpy.tabpy_server.handlers.management_handler import ManagementHandler from tabpy.tabpy_server.handlers.endpoint_handler import EndpointHandler from tabpy.tabpy_server.handlers.endpoints_handler import EndpointsHandler @@ -8,6 +7,3 @@ from tabpy.tabpy_server.handlers.query_plane_handler import QueryPlaneHandler from tabpy.tabpy_server.handlers.service_info_handler import ServiceInfoHandler from tabpy.tabpy_server.handlers.status_handler import StatusHandler -from tabpy.tabpy_server.handlers.upload_destination_handler import ( - UploadDestinationHandler, -) diff --git a/tabpy/tabpy_server/handlers/base_handler.py b/tabpy/tabpy_server/handlers/base_handler.py index dbbb6371..4b656566 100644 --- a/tabpy/tabpy_server/handlers/base_handler.py +++ b/tabpy/tabpy_server/handlers/base_handler.py @@ -1,6 +1,5 @@ import base64 import binascii -import concurrent import json import logging import tornado.web @@ -9,9 +8,6 @@ import uuid -STAGING_THREAD = concurrent.futures.ThreadPoolExecutor(max_workers=3) - - class ContextLoggerWrapper: """ This class appends request context to logged messages. 
diff --git a/tabpy/tabpy_server/handlers/endpoint_handler.py b/tabpy/tabpy_server/handlers/endpoint_handler.py index 022d8e0b..5be567ea 100644 --- a/tabpy/tabpy_server/handlers/endpoint_handler.py +++ b/tabpy/tabpy_server/handlers/endpoint_handler.py @@ -8,16 +8,10 @@ import json import logging -import shutil -from tabpy.tabpy_server.common.util import format_exception -from tabpy.tabpy_server.handlers import ManagementHandler -from tabpy.tabpy_server.handlers.base_handler import STAGING_THREAD -from tabpy.tabpy_server.management.state import get_query_object_path -from tabpy.tabpy_server.psws.callbacks import on_state_change -from tornado import gen +from tabpy.tabpy_server.handlers import MainHandler -class EndpointHandler(ManagementHandler): +class EndpointHandler(MainHandler): def initialize(self, app): super(EndpointHandler, self).initialize(app) @@ -26,116 +20,15 @@ def get(self, endpoint_name): self.fail_with_not_authorized() return - self.logger.log(logging.DEBUG, f"Processing GET for /endpoints/{endpoint_name}") + self.logger.log(logging.DEBUG, f"Processing GET for /endpoints/{endpoint_name}...") self._add_CORS_header() - if not endpoint_name: - self.write(json.dumps(self.tabpy_state.get_endpoints())) - else: - if endpoint_name in self.tabpy_state.get_endpoints(): - self.write(json.dumps(self.tabpy_state.get_endpoints()[endpoint_name])) - else: - self.error_out( - 404, - "Unknown endpoint", - info=f"Endpoint {endpoint_name} is not found", - ) - - @gen.coroutine - def put(self, name): - if self.should_fail_with_not_authorized(): - self.fail_with_not_authorized() - return - - self.logger.log(logging.DEBUG, f"Processing PUT for /endpoints/{name}") - try: - if not self.request.body: - self.error_out(400, "Input body cannot be empty") - self.finish() - return - try: - request_data = json.loads(self.request.body.decode("utf-8")) - except BaseException as ex: - self.error_out( - 400, log_message="Failed to decode input body", info=str(ex) - ) - self.finish() - 
return - - # check if endpoint exists - endpoints = self.tabpy_state.get_endpoints(name) - if len(endpoints) == 0: - self.error_out(404, f"endpoint {name} does not exist.") - self.finish() - return - - new_version = int(endpoints[name]["version"]) + 1 - self.logger.log(logging.INFO, f"Endpoint info: {request_data}") - err_msg = yield self._add_or_update_endpoint( - "update", name, new_version, request_data + if endpoint_name in self.app.models: + self.write(json.dumps(endpoint_name)) + else: + self.error_out( + 404, + f"Unknown endpoint {endpoint_name}", + info=f"Endpoint {endpoint_name} is not found", ) - if err_msg: - self.error_out(400, err_msg) - self.finish() - else: - self.write(self.tabpy_state.get_endpoints(name)) - self.finish() - - except Exception as e: - err_msg = format_exception(e, "update_endpoint") - self.error_out(500, err_msg) - self.finish() - - @gen.coroutine - def delete(self, name): - if self.should_fail_with_not_authorized(): - self.fail_with_not_authorized() - return - - self.logger.log(logging.DEBUG, f"Processing DELETE for /endpoints/{name}") - - try: - endpoints = self.tabpy_state.get_endpoints(name) - if len(endpoints) == 0: - self.error_out(404, f"endpoint {name} does not exist.") - self.finish() - return - - # update state - try: - endpoint_info = self.tabpy_state.delete_endpoint(name) - except Exception as e: - self.error_out(400, f"Error when removing endpoint: {e.message}") - self.finish() - return - - # delete files - if endpoint_info["type"] != "alias": - delete_path = get_query_object_path( - self.settings["state_file_path"], name, None - ) - try: - yield self._delete_po_future(delete_path) - except Exception as e: - self.error_out(400, f"Error while deleting: {e}") - self.finish() - return - - self.set_status(204) - self.finish() - - except Exception as e: - err_msg = format_exception(e, "delete endpoint") - self.error_out(500, err_msg) - self.finish() - - on_state_change( - self.settings, self.tabpy_state, self.python_service, 
self.logger - ) - - @gen.coroutine - def _delete_po_future(self, delete_path): - future = STAGING_THREAD.submit(shutil.rmtree, delete_path) - ret = yield future - raise gen.Return(ret) diff --git a/tabpy/tabpy_server/handlers/endpoints_handler.py b/tabpy/tabpy_server/handlers/endpoints_handler.py index 66132dd2..259756cb 100644 --- a/tabpy/tabpy_server/handlers/endpoints_handler.py +++ b/tabpy/tabpy_server/handlers/endpoints_handler.py @@ -7,13 +7,10 @@ """ import json -import logging -from tabpy.tabpy_server.common.util import format_exception -from tabpy.tabpy_server.handlers import ManagementHandler -from tornado import gen +from tabpy.tabpy_server.handlers import BaseHandler -class EndpointsHandler(ManagementHandler): +class EndpointsHandler(BaseHandler): def initialize(self, app): super(EndpointsHandler, self).initialize(app) @@ -24,52 +21,3 @@ def get(self): self._add_CORS_header() self.write(json.dumps(self.tabpy_state.get_endpoints())) - - @gen.coroutine - def post(self): - if self.should_fail_with_not_authorized(): - self.fail_with_not_authorized() - return - - try: - if not self.request.body: - self.error_out(400, "Input body cannot be empty") - self.finish() - return - - try: - request_data = json.loads(self.request.body.decode("utf-8")) - except Exception as ex: - self.error_out(400, "Failed to decode input body", str(ex)) - self.finish() - return - - if "name" not in request_data: - self.error_out(400, "name is required to add an endpoint.") - self.finish() - return - - name = request_data["name"] - - # check if endpoint already exist - if name in self.tabpy_state.get_endpoints(): - self.error_out(400, f"endpoint {name} already exists.") - self.finish() - return - - self.logger.log(logging.DEBUG, f'Adding endpoint "{name}"') - err_msg = yield self._add_or_update_endpoint("add", name, 1, request_data) - if err_msg: - self.error_out(400, err_msg) - else: - self.logger.log(logging.DEBUG, f"Endpoint {name} successfully added") - self.set_status(201) - 
self.write(self.tabpy_state.get_endpoints(name)) - self.finish() - return - - except Exception as e: - err_msg = format_exception(e, "/add_endpoint") - self.error_out(500, "error adding endpoint", err_msg) - self.finish() - return diff --git a/tabpy/tabpy_server/handlers/evaluation_plane_handler.py b/tabpy/tabpy_server/handlers/evaluation_plane_handler.py index 390aff04..21a8f837 100644 --- a/tabpy/tabpy_server/handlers/evaluation_plane_handler.py +++ b/tabpy/tabpy_server/handlers/evaluation_plane_handler.py @@ -2,10 +2,9 @@ import json import simplejson import logging -from tabpy.tabpy_server.common.util import format_exception import requests -from tornado import gen from datetime import timedelta +from tornado import gen class RestrictedTabPy: @@ -43,7 +42,7 @@ def initialize(self, executor, app): @gen.coroutine def _post_impl(self): body = json.loads(self.request.body.decode("utf-8")) - self.logger.log(logging.DEBUG, f"Processing POST request '{body}'...") + self.logger.log(logging.DEBUG, f"Processing POST /evaluate ...") if "script" not in body: self.error_out(400, "Script is empty.") return @@ -109,18 +108,7 @@ def post(self): yield self._post_impl() except Exception as e: err_msg = f"{e.__class__.__name__} : {str(e)}" - if err_msg != "KeyError : 'response'": - err_msg = format_exception(e, "POST /evaluate") - self.error_out(500, "Error processing script", info=err_msg) - else: - self.error_out( - 404, - "Error processing script", - info="The endpoint you're " - "trying to query did not respond. 
Please make sure the " - "endpoint exists and the correct set of arguments are " - "provided.", - ) + self.error_out(500, f"Error processing script: {err_msg}", info=err_msg) @gen.coroutine def _call_subprocess(self, function_to_evaluate, arguments): diff --git a/tabpy/tabpy_server/handlers/management_handler.py b/tabpy/tabpy_server/handlers/management_handler.py deleted file mode 100644 index 90e8a541..00000000 --- a/tabpy/tabpy_server/handlers/management_handler.py +++ /dev/null @@ -1,160 +0,0 @@ -import logging -import os -import shutil -from re import compile as _compile -from uuid import uuid4 as random_uuid - -from tornado import gen - -from tabpy.tabpy_server.app.SettingsParameters import SettingsParameters -from tabpy.tabpy_server.handlers import MainHandler -from tabpy.tabpy_server.handlers.base_handler import STAGING_THREAD -from tabpy.tabpy_server.management.state import get_query_object_path -from tabpy.tabpy_server.psws.callbacks import on_state_change - - -def copy_from_local(localpath, remotepath, is_dir=False): - if is_dir: - if not os.path.exists(remotepath): - # remote folder does not exist - shutil.copytree(localpath, remotepath) - else: - # remote folder exists, copy each file - src_files = os.listdir(localpath) - for file_name in src_files: - full_file_name = os.path.join(localpath, file_name) - if os.path.isdir(full_file_name): - # copy folder recursively - full_remote_path = os.path.join(remotepath, file_name) - shutil.copytree(full_file_name, full_remote_path) - else: - # copy each file - shutil.copy(full_file_name, remotepath) - else: - shutil.copy(localpath, remotepath) - - -class ManagementHandler(MainHandler): - def initialize(self, app): - super(ManagementHandler, self).initialize(app) - self.port = self.settings[SettingsParameters.Port] - - def _get_protocol(self): - return "http://" - - @gen.coroutine - def _add_or_update_endpoint(self, action, name, version, request_data): - """ - Add or update an endpoint - """ - 
self.logger.log(logging.DEBUG, f"Adding/updating model {name}...") - - _name_checker = _compile(r"^[a-zA-Z0-9-_\s]+$") - if not isinstance(name, str): - msg = "Endpoint name must be a string" - self.logger.log(logging.CRITICAL, msg) - raise TypeError(msg) - - if not _name_checker.match(name): - raise gen.Return( - "endpoint name can only contain: a-z, A-Z, 0-9," - " underscore, hyphens and spaces." - ) - - if self.settings.get("add_or_updating_endpoint"): - msg = ( - "Another endpoint update is already in progress" - ", please wait a while and try again" - ) - self.logger.log(logging.CRITICAL, msg) - raise RuntimeError(msg) - - request_uuid = random_uuid() - self.settings["add_or_updating_endpoint"] = request_uuid - try: - description = ( - request_data["description"] if "description" in request_data else None - ) - if "docstring" in request_data: - docstring = str( - bytes(request_data["docstring"], "utf-8").decode("unicode_escape") - ) - else: - docstring = None - endpoint_type = request_data["type"] if "type" in request_data else None - methods = request_data["methods"] if "methods" in request_data else [] - dependencies = ( - request_data["dependencies"] if "dependencies" in request_data else None - ) - target = request_data["target"] if "target" in request_data else None - schema = request_data["schema"] if "schema" in request_data else None - - src_path = request_data["src_path"] if "src_path" in request_data else None - target_path = get_query_object_path( - self.settings[SettingsParameters.StateFilePath], name, version - ) - self.logger.log(logging.DEBUG, f"Checking source path {src_path}...") - _path_checker = _compile(r"^[\\\:a-zA-Z0-9-_~\s/\.\(\)]+$") - # copy from staging - if src_path: - if not isinstance(request_data["src_path"], str): - raise gen.Return("src_path must be a string.") - if not _path_checker.match(src_path): - raise gen.Return( - "Endpoint source path name can only contain: " - "a-z, A-Z, 0-9, underscore, hyphens and spaces." 
- ) - - yield self._copy_po_future(src_path, target_path) - elif endpoint_type != "alias": - raise gen.Return("src_path is required to add/update an " "endpoint.") - - # alias special logic: - if endpoint_type == "alias": - if not target: - raise gen.Return("Target is required for alias endpoint.") - dependencies = [target] - - # update local config - try: - if action == "add": - self.tabpy_state.add_endpoint( - name=name, - description=description, - docstring=docstring, - endpoint_type=endpoint_type, - methods=methods, - dependencies=dependencies, - target=target, - schema=schema, - ) - else: - self.tabpy_state.update_endpoint( - name=name, - description=description, - docstring=docstring, - endpoint_type=endpoint_type, - methods=methods, - dependencies=dependencies, - target=target, - schema=schema, - version=version, - ) - - except Exception as e: - raise gen.Return(f"Error when changing TabPy state: {e}") - - on_state_change( - self.settings, self.tabpy_state, self.python_service, self.logger - ) - - finally: - self.settings["add_or_updating_endpoint"] = None - - @gen.coroutine - def _copy_po_future(self, src_path, target_path): - future = STAGING_THREAD.submit( - copy_from_local, src_path, target_path, is_dir=True - ) - ret = yield future - raise gen.Return(ret) diff --git a/tabpy/tabpy_server/handlers/query_plane_handler.py b/tabpy/tabpy_server/handlers/query_plane_handler.py index aab42593..1b7f3f70 100644 --- a/tabpy/tabpy_server/handlers/query_plane_handler.py +++ b/tabpy/tabpy_server/handlers/query_plane_handler.py @@ -1,16 +1,8 @@ from tabpy.tabpy_server.handlers import BaseHandler import logging import time -from tabpy.tabpy_server.common.messages import ( - Query, - QuerySuccessful, - QueryError, - UnknownURI, -) -from hashlib import md5 import uuid import json -from tabpy.tabpy_server.common.util import format_exception import urllib from tornado import gen @@ -24,44 +16,24 @@ class QueryPlaneHandler(BaseHandler): def initialize(self, app): 
super(QueryPlaneHandler, self).initialize(app) - def _query(self, po_name, data, uid, qry): + def _query(self, model, data): """ Parameters ---------- - po_name : str - The name of the query object to query + model : + Model (function) to execute data : dict The deserialized request body - uid: str - A unique identifier for the request - - qry: str - The incoming query object. This object maintains - raw incoming request, which is different from the sanitied data - Returns ------- - out : (result type, dict, int) - A triple containing a result type, the result message - as a dictionary, and the time in seconds that it took to complete - the request. + out : object + Result. """ - self.logger.log(logging.DEBUG, f"Collecting query info for {po_name}...") - start_time = time.time() - response = self.python_service.ps.query(po_name, data, uid) - gls_time = time.time() - start_time - self.logger.log(logging.DEBUG, f"Query info: {response}") - - if isinstance(response, QuerySuccessful): - response_json = response.to_json() - md5_tag = md5(response_json.encode("utf-8")).hexdigest() - self.set_header("Etag", f'"{md5_tag}"') - return (QuerySuccessful, response.for_json(), gls_time) - else: - self.logger.log(logging.ERROR, f"Failed query, response: {response}") - return (type(response), response.for_json(), gls_time) + response = model(data) + response_json = response.to_json() + return response_json # handle HTTP Options requests to support CORS # don't check API key (client does not send or receive data for OPTIONS, @@ -71,59 +43,32 @@ def options(self, pred_name): self.fail_with_not_authorized() return - self.logger.log(logging.DEBUG, f"Processing OPTIONS for /query/{pred_name}") + self.logger.log(logging.DEBUG, f"Processing OPTIONS for /query/{pred_name}...") # add CORS headers if TabPy has a cors_origin specified self._add_CORS_header() self.write({}) - def _handle_result(self, po_name, data, qry, uid): - (response_type, response, gls_time) = self._query(po_name, data, 
uid, qry) - - if response_type == QuerySuccessful: - result_dict = { - "response": response["response"], - "version": response["version"], - "model": po_name, - "uuid": uid, - } - self.write(result_dict) - self.finish() - return (gls_time, response["response"]) - else: - if response_type == UnknownURI: - self.error_out( - 404, - "UnknownURI", - info=( - "No query object has been registered" - f' with the name "{po_name}"' - ), - ) - elif response_type == QueryError: - self.error_out(400, "QueryError", info=response) - else: - self.error_out(500, "Error querying GLS", info=response) + def _handle_result(self, model, data): + response = self._query(model, data) - return (None, None) + result = { + "response": response["response"], + } + self.write(result) + self.finish() + return response["response"] def _sanitize_request_data(self, data): if not isinstance(data, dict): - msg = "Input data must be a dictionary" - self.logger.log(logging.CRITICAL, msg) - raise RuntimeError(msg) + raise RuntimeError("Input data must be a dictionary") - if "method" in data: - return {"data": data.get("data"), "method": data.get("method")} - elif "data" in data: + if "data" in data: return data.get("data") else: - msg = 'Input data must be a dictionary with a key called "data"' - self.logger.log(logging.CRITICAL, msg) - raise RuntimeError(msg) + raise RuntimeError('Input data must be a dictionary with a key called "data"') - def _process_query(self, endpoint_name, start): - self.logger.log(logging.DEBUG, f"Processing query {endpoint_name}...") + def _process_query(self, endpoint_name): try: self._add_CORS_header() @@ -136,89 +81,27 @@ def _process_query(self, endpoint_name, start): # Sanitize input data data = self._sanitize_request_data(json.loads(request_json)) except Exception as e: - self.logger.log(logging.ERROR, str(e)) - err_msg = format_exception(e, "Invalid Input Data") - self.error_out(400, err_msg) + msg = str(e) + self.logger.log(logging.ERROR, msg) + self.error_out(400, 
f"Invalid Input Data: {msg}") return try: - (po_name, _) = self._get_actual_model(endpoint_name) - - # po_name is None if self.python_service.ps.query_objects.get( - # endpoint_name) is None - if not po_name: + if endpoint_name not in self.app.models: self.error_out( 404, "UnknownURI", info=f'Endpoint "{endpoint_name}" does not exist' ) return - po_obj = self.python_service.ps.query_objects.get(po_name) - - if not po_obj: - self.error_out( - 404, "UnknownURI", info=f'Endpoint "{po_name}" does not exist' - ) - return - - if po_name != endpoint_name: - self.logger.log( - logging.INFO, f"Querying actual model: po_name={po_name}" - ) - - uid = _get_uuid() - - # record query w/ request ID in query log - qry = Query(po_name, request_json) - gls_time = 0 - # send a query to PythonService and return - (gls_time, _) = self._handle_result(po_name, data, qry, uid) - - # if error occurred, GLS time is None. - if not gls_time: - return + self.logger.log(logging.INFO, f"Executing model '{endpoint_name}'...") + model = self.app.models[endpoint_name] + self._handle_result(model, data) except Exception as e: - self.logger.log(logging.ERROR, str(e)) - err_msg = format_exception(e, "process query") - self.error_out(500, "Error processing query", info=err_msg) - return - - def _get_actual_model(self, endpoint_name): - # Find the actual query to run from given endpoint - all_endpoint_names = [] - - while True: - endpoint_info = self.python_service.ps.query_objects.get(endpoint_name) - if not endpoint_info: - return [None, None] - - all_endpoint_names.append(endpoint_name) - - endpoint_type = endpoint_info.get("type", "model") - - if endpoint_type == "alias": - endpoint_name = endpoint_info["endpoint_obj"] - elif endpoint_type == "model": - break - else: - self.error_out( - 500, - "Unknown endpoint type", - info=f'Endpoint type "{endpoint_type}" does not exist', - ) - return - - return (endpoint_name, all_endpoint_names) - - @gen.coroutine - def get(self, endpoint_name): - if 
self.should_fail_with_not_authorized(): - self.fail_with_not_authorized() - return - - start = time.time() - endpoint_name = urllib.parse.unquote(endpoint_name) - self._process_query(endpoint_name, start) + msg = str(e) + self.logger.log(logging.ERROR, msg) + err_msg = f"Error processing query: {msg}" + self.error_out(500, err_msg, info=err_msg) @gen.coroutine def post(self, endpoint_name): @@ -228,6 +111,5 @@ def post(self, endpoint_name): self.fail_with_not_authorized() return - start = time.time() endpoint_name = urllib.parse.unquote(endpoint_name) - self._process_query(endpoint_name, start) + self._process_query(endpoint_name) diff --git a/tabpy/tabpy_server/handlers/service_info_handler.py b/tabpy/tabpy_server/handlers/service_info_handler.py index 6b7060fb..630270c3 100644 --- a/tabpy/tabpy_server/handlers/service_info_handler.py +++ b/tabpy/tabpy_server/handlers/service_info_handler.py @@ -1,9 +1,9 @@ import json from tabpy.tabpy_server.app.SettingsParameters import SettingsParameters -from tabpy.tabpy_server.handlers import ManagementHandler +from tabpy.tabpy_server.handlers import BaseHandler -class ServiceInfoHandler(ManagementHandler): +class ServiceInfoHandler(BaseHandler): def initialize(self, app): super(ServiceInfoHandler, self).initialize(app) diff --git a/tabpy/tabpy_server/handlers/upload_destination_handler.py b/tabpy/tabpy_server/handlers/upload_destination_handler.py deleted file mode 100644 index 729aff3e..00000000 --- a/tabpy/tabpy_server/handlers/upload_destination_handler.py +++ /dev/null @@ -1,20 +0,0 @@ -from tabpy.tabpy_server.app.SettingsParameters import SettingsParameters -from tabpy.tabpy_server.handlers import ManagementHandler -import os - - -_QUERY_OBJECT_STAGING_FOLDER = "staging" - - -class UploadDestinationHandler(ManagementHandler): - def initialize(self, app): - super(UploadDestinationHandler, self).initialize(app) - - def get(self): - if self.should_fail_with_not_authorized(): - self.fail_with_not_authorized() - return - - 
path = self.settings[SettingsParameters.StateFilePath] - path = os.path.join(path, _QUERY_OBJECT_STAGING_FOLDER) - self.write({"path": path}) diff --git a/tabpy/tabpy_server/management/state.py b/tabpy/tabpy_server/management/state.py index c36f7710..4bed50f3 100644 --- a/tabpy/tabpy_server/management/state.py +++ b/tabpy/tabpy_server/management/state.py @@ -2,44 +2,11 @@ from ConfigParser import ConfigParser except ImportError: from configparser import ConfigParser -import json import logging -from tabpy.tabpy_server.management.util import write_state_config -from threading import Lock -from time import time -logger = logging.getLogger(__name__) - -# State File Config Section Names -_DEPLOYMENT_SECTION_NAME = "Query Objects Service Versions" -_QUERY_OBJECT_DOCSTRING = "Query Objects Docstrings" _SERVICE_INFO_SECTION_NAME = "Service Info" -_META_SECTION_NAME = "Meta" - -# Directory Names -_QUERY_OBJECT_DIR = "query_objects" - -""" -Lock to change the TabPy State. -""" -_PS_STATE_LOCK = Lock() - - -def state_lock(func): - """ - Mutex for changing PS state - """ - - def wrapper(self, *args, **kwargs): - try: - _PS_STATE_LOCK.acquire() - return func(self, *args, **kwargs) - finally: - # ALWAYS RELEASE LOCK - _PS_STATE_LOCK.release() - - return wrapper +logger = logging.getLogger(__name__) def _get_root_path(state_path): @@ -49,20 +16,6 @@ def _get_root_path(state_path): return state_path -def get_query_object_path(state_file_path, name, version): - """ - Returns the query object path - - If the version is None, a path without the version will be returned. 
- """ - root_path = _get_root_path(state_file_path) - if version is not None: - full_path = root_path + "/".join([_QUERY_OBJECT_DIR, name, str(version)]) - else: - full_path = root_path + "/".join([_QUERY_OBJECT_DIR, name]) - return full_path - - class TabPyState: """ The TabPy state object that stores attributes @@ -80,12 +33,12 @@ class TabPyState: """ - def __init__(self, settings, config=None): + def __init__(self, settings, models, config=None): self.settings = settings - self.set_config(config, _update=False) + self.models = models + self.set_config(config) - @state_lock - def set_config(self, config, logger=logging.getLogger(__name__), _update=True): + def set_config(self, config, logger=logging.getLogger(__name__)): """ Set the local ConfigParser manually. This new ConfigParser will be used as current state. @@ -93,309 +46,9 @@ def set_config(self, config, logger=logging.getLogger(__name__), _update=True): if not isinstance(config, ConfigParser): raise ValueError("Invalid config") self.config = config - if _update: - self._write_state(logger) - - def get_endpoints(self, name=None): - """ - Return a dictionary of endpoints - - Parameters - ---------- - name : str - The name of the endpoint. - If "name" is specified, only the information about that endpoint - will be returned. - - Returns - ------- - endpoints : dict - The dictionary containing information about each endpoint. - The keys are the endpoint names. 
- The values for each include: - - description - - doc string - - type - - target - - """ - endpoints = {} - try: - endpoint_names = self._get_config_value(_DEPLOYMENT_SECTION_NAME, name) - except Exception as e: - logger.error(f"error in get_endpoints: {str(e)}") - return {} - - if name: - endpoint_info = json.loads(endpoint_names) - docstring = self._get_config_value(_QUERY_OBJECT_DOCSTRING, name) - endpoint_info["docstring"] = str( - bytes(docstring, "utf-8").decode("unicode_escape") - ) - endpoints = {name: endpoint_info} - else: - for endpoint_name in endpoint_names: - endpoint_info = json.loads( - self._get_config_value(_DEPLOYMENT_SECTION_NAME, endpoint_name) - ) - docstring = self._get_config_value( - _QUERY_OBJECT_DOCSTRING, endpoint_name, True, "" - ) - endpoint_info["docstring"] = str( - bytes(docstring, "utf-8").decode("unicode_escape") - ) - endpoints[endpoint_name] = endpoint_info - logger.debug(f"Collected endpoints: {endpoints}") - return endpoints - - @state_lock - def add_endpoint( - self, - name, - description=None, - docstring=None, - endpoint_type=None, - methods=None, - target=None, - dependencies=None, - schema=None, - ): - """ - Add a new endpoint to the TabPy. - - Parameters - ---------- - name : str - Name of the endpoint - description : str, optional - Description of this endpoint - doc_string : str, optional - The doc string for this endpoint, if needed. - endpoint_type : str - The endpoint type (model, alias) - target : str, optional - The target endpoint name for the alias to be added. - - Note: - The version of this endpoint will be set to 1 since it is a new - endpoint. 
- - """ - try: - endpoints = self.get_endpoints() - if name is None or not isinstance(name, str) or len(name) == 0: - raise ValueError("name of the endpoint must be a valid string.") - elif name in endpoints: - raise ValueError(f"endpoint {name} already exists.") - if description and not isinstance(description, str): - raise ValueError("description must be a string.") - elif not description: - description = "" - if docstring and not isinstance(docstring, str): - raise ValueError("docstring must be a string.") - elif not docstring: - docstring = "-- no docstring found in query function --" - if not endpoint_type or not isinstance(endpoint_type, str): - raise ValueError("endpoint type must be a string.") - if dependencies and not isinstance(dependencies, list): - raise ValueError("dependencies must be a list.") - elif not dependencies: - dependencies = [] - if target and not isinstance(target, str): - raise ValueError("target must be a string.") - elif target and target not in endpoints: - raise ValueError("target endpoint is not valid.") - - endpoint_info = { - "description": description, - "docstring": docstring, - "type": endpoint_type, - "version": 1, - "dependencies": dependencies, - "target": target, - "creation_time": int(time()), - "last_modified_time": int(time()), - "schema": schema, - } - - endpoints[name] = endpoint_info - self._add_update_endpoints_config(endpoints) - except Exception as e: - logger.error(f"Error in add_endpoint: {e}") - raise - - def _add_update_endpoints_config(self, endpoints): - # save the endpoint info to config - dstring = "" - for endpoint_name in endpoints: - try: - info = endpoints[endpoint_name] - dstring = str( - bytes(info["docstring"], "utf-8").decode("unicode_escape") - ) - self._set_config_value( - _QUERY_OBJECT_DOCSTRING, - endpoint_name, - dstring, - _update_revision=False, - ) - del info["docstring"] - self._set_config_value( - _DEPLOYMENT_SECTION_NAME, endpoint_name, json.dumps(info) - ) - except Exception as e: - 
logger.error(f"Unable to write endpoints config: {e}") - raise - - @state_lock - def update_endpoint( - self, - name, - description=None, - docstring=None, - endpoint_type=None, - version=None, - methods=None, - target=None, - dependencies=None, - schema=None, - ): - """ - Update an existing endpoint on the TabPy. - - Parameters - ---------- - name : str - Name of the endpoint - description : str, optional - Description of this endpoint - doc_string : str, optional - The doc string for this endpoint, if needed. - endpoint_type : str, optional - The endpoint type (model, alias) - version : str, optional - The version of this endpoint - dependencies=[] - List of dependent endpoints for this existing endpoint - target : str, optional - The target endpoint name for the alias. - - Note: - For those parameters that are not specified, those values will not - get changed. - - """ - try: - endpoints = self.get_endpoints() - if not name or not isinstance(name, str): - raise ValueError("name of the endpoint must be string.") - elif name not in endpoints: - raise ValueError(f"endpoint {name} does not exist.") - - endpoint_info = endpoints[name] - - if description and not isinstance(description, str): - raise ValueError("description must be a string.") - elif not description: - description = endpoint_info["description"] - if docstring and not isinstance(docstring, str): - raise ValueError("docstring must be a string.") - elif not docstring: - docstring = endpoint_info["docstring"] - if endpoint_type and not isinstance(endpoint_type, str): - raise ValueError("endpoint type must be a string.") - elif not endpoint_type: - endpoint_type = endpoint_info["type"] - if version and not isinstance(version, int): - raise ValueError("version must be an int.") - elif not version: - version = endpoint_info["version"] - if dependencies and not isinstance(dependencies, list): - raise ValueError("dependencies must be a list.") - elif not dependencies: - if "dependencies" in endpoint_info: - 
dependencies = endpoint_info["dependencies"] - else: - dependencies = [] - if target and not isinstance(target, str): - raise ValueError("target must be a string.") - elif target and target not in endpoints: - raise ValueError("target endpoint is not valid.") - elif not target: - target = endpoint_info["target"] - endpoint_info = { - "description": description, - "docstring": docstring, - "type": endpoint_type, - "version": version, - "dependencies": dependencies, - "target": target, - "creation_time": endpoint_info["creation_time"], - "last_modified_time": int(time()), - "schema": schema, - } - - endpoints[name] = endpoint_info - self._add_update_endpoints_config(endpoints) - except Exception as e: - logger.error(f"Error in update_endpoint: {e}") - raise - - @state_lock - def delete_endpoint(self, name): - """ - Delete an existing endpoint on the TabPy - - Parameters - ---------- - name : str - The name of the endpoint to be deleted. - Returns - ------- - deleted endpoint object - - Note: - Cannot delete this endpoint if other endpoints are currently - depending on this endpoint. - - """ - if not name or name == "": - raise ValueError("Name of the endpoint must be a valid string.") - endpoints = self.get_endpoints() - if name not in endpoints: - raise ValueError(f"Endpoint {name} does not exist.") - - endpoint_to_delete = endpoints[name] - - # get dependencies and target - deps = set() - for endpoint_name in endpoints: - if endpoint_name != name: - deps_list = endpoints[endpoint_name].get("dependencies", []) - if name in deps_list: - deps.add(endpoint_name) - - # check if other endpoints are depending on this endpoint - if len(deps) > 0: - raise ValueError( - f"Cannot remove endpoint {name}, it is currently " - f"used by {list(deps)} endpoints." 
- ) - - del endpoints[name] - - # delete the endpoint from state - try: - self._remove_config_option( - _QUERY_OBJECT_DOCSTRING, name, _update_revision=False - ) - self._remove_config_option(_DEPLOYMENT_SECTION_NAME, name) - - return endpoint_to_delete - except Exception as e: - logger.error(f"Unable to delete endpoint {e}") - raise ValueError(f"Unable to delete endpoint: {e}") + def get_endpoints(self): + return self.models.keys() @property def name(self): @@ -423,23 +76,6 @@ def creation_time(self): logger.error(f"Unable to get name: {e}") return creation_time - @state_lock - def set_name(self, name): - """ - Set the name of this TabPy service. - - Parameters - ---------- - name : str - Name of TabPy service. - """ - if not isinstance(name, str): - raise ValueError("name must be a string.") - try: - self._set_config_value(_SERVICE_INFO_SECTION_NAME, "Name", name) - except Exception as e: - logger.error(f"Unable to set name: {e}") - def get_description(self): """ Returns the description of the TabPy service. @@ -453,36 +89,6 @@ def get_description(self): logger.error(f"Unable to get description: {e}") return description - @state_lock - def set_description(self, description): - """ - Set the description of this TabPy service. - - Parameters - ---------- - description : str - Description of TabPy service. - """ - if not isinstance(description, str): - raise ValueError("Description must be a string.") - try: - self._set_config_value( - _SERVICE_INFO_SECTION_NAME, "Description", description - ) - except Exception as e: - logger.error(f"Unable to set description: {e}") - - def get_revision_number(self): - """ - Returns the revision number of this TabPy service. - """ - rev = -1 - try: - rev = int(self._get_config_value(_META_SECTION_NAME, "Revision Number")) - except Exception as e: - logger.error(f"Unable to get revision number: {e}") - return rev - def get_access_control_allow_origin(self): """ Returns Access-Control-Allow-Origin of this TabPy service. 
@@ -523,71 +129,6 @@ def get_access_control_allow_methods(self): pass return _cors_methods - def _set_revision_number(self, revision_number): - """ - Set the revision number of this TabPy service. - """ - if not isinstance(revision_number, int): - raise ValueError("revision number must be an int.") - try: - self._set_config_value( - _META_SECTION_NAME, "Revision Number", revision_number - ) - except Exception as e: - logger.error(f"Unable to set revision number: {e}") - - def _remove_config_option( - self, - section_name, - option_name, - logger=logging.getLogger(__name__), - _update_revision=True, - ): - if not self.config: - raise ValueError("State configuration not yet loaded.") - self.config.remove_option(section_name, option_name) - # update revision number - if _update_revision: - self._increase_revision_number() - self._write_state(logger=logger) - - def _has_config_value(self, section_name, option_name): - if not self.config: - raise ValueError("State configuration not yet loaded.") - return self.config.has_option(section_name, option_name) - - def _increase_revision_number(self): - if not self.config: - raise ValueError("State configuration not yet loaded.") - cur_rev = int(self.config.get(_META_SECTION_NAME, "Revision Number")) - self.config.set(_META_SECTION_NAME, "Revision Number", str(cur_rev + 1)) - - def _set_config_value( - self, - section_name, - option_name, - option_value, - logger=logging.getLogger(__name__), - _update_revision=True, - ): - if not self.config: - raise ValueError("State configuration not yet loaded.") - - if not self.config.has_section(section_name): - logger.log(logging.DEBUG, f"Adding config section {section_name}") - self.config.add_section(section_name) - - self.config.set(section_name, option_name, option_value) - # update revision number - if _update_revision: - self._increase_revision_number() - self._write_state(logger=logger) - - def _get_config_items(self, section_name): - if not self.config: - raise ValueError("State 
configuration not yet loaded.") - return self.config.items(section_name) - def _get_config_value( self, section_name, option_name, optional=False, default_value=None ): @@ -615,10 +156,3 @@ def _get_config_value( logger.log(logging.DEBUG, f"Returning value '{res}'") return res - - def _write_state(self, logger=logging.getLogger(__name__)): - """ - Write state (ConfigParser) to Consul - """ - logger.log(logging.INFO, "Writing state to config") - write_state_config(self.config, self.settings, logger=logger) diff --git a/tabpy/tabpy_server/psws/__init__.py b/tabpy/tabpy_server/psws/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tabpy/tabpy_server/psws/callbacks.py b/tabpy/tabpy_server/psws/callbacks.py deleted file mode 100644 index 4b1fe14e..00000000 --- a/tabpy/tabpy_server/psws/callbacks.py +++ /dev/null @@ -1,205 +0,0 @@ -import logging -from tabpy.tabpy_server.app.SettingsParameters import SettingsParameters -from tabpy.tabpy_server.common.messages import ( - LoadObject, - DeleteObjects, - ListObjects, - ObjectList, -) -from tabpy.tabpy_server.common.endpoint_file_mgr import cleanup_endpoint_files -from tabpy.tabpy_server.common.util import format_exception -from tabpy.tabpy_server.management.state import TabPyState, get_query_object_path -from tabpy.tabpy_server.management import util -from time import sleep -from tornado import gen - - -logger = logging.getLogger(__name__) - - -def wait_for_endpoint_loaded(python_service, object_uri): - """ - This method waits for the object to be loaded. 
- """ - logger.info("Waiting for object to be loaded...") - while True: - msg = ListObjects() - list_object_msg = python_service.manage_request(msg) - if not isinstance(list_object_msg, ObjectList): - logger.error(f"Error loading endpoint {object_uri}: {list_object_msg}") - return - - for (uri, info) in list_object_msg.objects.items(): - if uri == object_uri: - if info["status"] != "LoadInProgress": - logger.info(f'Object load status: {info["status"]}') - return - - sleep(0.1) - - -@gen.coroutine -def init_ps_server(settings, tabpy_state): - logger.info("Initializing TabPy Server...") - existing_pos = tabpy_state.get_endpoints() - for (object_name, obj_info) in existing_pos.items(): - try: - object_version = obj_info["version"] - get_query_object_path( - settings[SettingsParameters.StateFilePath], object_name, object_version - ) - except Exception as e: - logger.error( - f"Exception encounted when downloading object: {object_name}" - f", error: {e}" - ) - - -@gen.coroutine -def init_model_evaluator(settings, tabpy_state, python_service): - """ - This will go through all models that the service currently have and - initialize them. 
- """ - logger.info("Initializing models...") - - existing_pos = tabpy_state.get_endpoints() - - for (object_name, obj_info) in existing_pos.items(): - object_version = obj_info["version"] - object_type = obj_info["type"] - object_path = get_query_object_path( - settings[SettingsParameters.StateFilePath], object_name, object_version - ) - - logger.info( - f"Load endpoint: {object_name}, " - f"version: {object_version}, " - f"type: {object_type}" - ) - if object_type == "alias": - msg = LoadObject( - object_name, obj_info["target"], object_version, False, "alias" - ) - else: - local_path = object_path - msg = LoadObject( - object_name, local_path, object_version, False, object_type - ) - python_service.manage_request(msg) - - -def _get_latest_service_state(settings, tabpy_state, new_ps_state, python_service): - """ - Update the endpoints from the latest remote state file. - - Returns - -------- - (has_changes, endpoint_diff): - has_changes: True or False - endpoint_diff: Summary of what has changed, one entry for each changes - """ - # Shortcut when nothing is changed - changes = {"endpoints": {}} - - # update endpoints - new_endpoints = new_ps_state.get_endpoints() - diff = {} - current_endpoints = python_service.ps.query_objects - for (endpoint_name, endpoint_info) in new_endpoints.items(): - existing_endpoint = current_endpoints.get(endpoint_name) - if (existing_endpoint is None) or endpoint_info["version"] != existing_endpoint[ - "version" - ]: - # Either a new endpoint or new endpoint version - path_to_new_version = get_query_object_path( - settings[SettingsParameters.StateFilePath], - endpoint_name, - endpoint_info["version"], - ) - endpoint_type = endpoint_info.get("type", "model") - diff[endpoint_name] = ( - endpoint_type, - endpoint_info["version"], - path_to_new_version, - ) - - # add removed models too - for (endpoint_name, endpoint_info) in current_endpoints.items(): - if endpoint_name not in new_endpoints.keys(): - endpoint_type = 
current_endpoints[endpoint_name].get("type", "model") - diff[endpoint_name] = (endpoint_type, None, None) - - if diff: - changes["endpoints"] = diff - - return (True, changes) - - -@gen.coroutine -def on_state_change( - settings, tabpy_state, python_service, logger=logging.getLogger(__name__) -): - try: - logger.log(logging.INFO, "Loading state from state file") - config = util._get_state_from_file( - settings[SettingsParameters.StateFilePath], logger=logger - ) - new_ps_state = TabPyState(config=config, settings=settings) - - (has_changes, changes) = _get_latest_service_state( - settings, tabpy_state, new_ps_state, python_service - ) - if not has_changes: - logger.info("Nothing changed, return.") - return - - new_endpoints = new_ps_state.get_endpoints() - for object_name in changes["endpoints"]: - (object_type, object_version, object_path) = changes["endpoints"][ - object_name - ] - - if not object_path and not object_version: # removal - logger.info(f"Removing object: URI={object_name}") - - python_service.manage_request(DeleteObjects([object_name])) - - cleanup_endpoint_files( - object_name, settings[SettingsParameters.UploadDir], logger=logger - ) - - else: - endpoint_info = new_endpoints[object_name] - is_update = object_version > 1 - if object_type == "alias": - msg = LoadObject( - object_name, - endpoint_info["target"], - object_version, - is_update, - "alias", - ) - else: - local_path = object_path - msg = LoadObject( - object_name, local_path, object_version, is_update, object_type - ) - - python_service.manage_request(msg) - wait_for_endpoint_loaded(python_service, object_name) - - # cleanup old version of endpoint files - if object_version > 2: - cleanup_endpoint_files( - object_name, - settings[SettingsParameters.UploadDir], - logger=logger, - retain_versions=[object_version, object_version - 1], - ) - - except Exception as e: - err_msg = format_exception(e, "on_state_change") - logger.log( - logging.ERROR, f"Error submitting update model request: 
error={err_msg}" - ) diff --git a/tabpy/tabpy_server/psws/python_service.py b/tabpy/tabpy_server/psws/python_service.py deleted file mode 100644 index 5352b40b..00000000 --- a/tabpy/tabpy_server/psws/python_service.py +++ /dev/null @@ -1,275 +0,0 @@ -import concurrent.futures -import logging -from tabpy.tabpy_tools.query_object import QueryObject -from tabpy.tabpy_server.common.util import format_exception -from tabpy.tabpy_server.common.messages import ( - LoadObject, - DeleteObjects, - FlushObjects, - CountObjects, - ListObjects, - UnknownMessage, - LoadFailed, - ObjectsDeleted, - ObjectsFlushed, - QueryFailed, - QuerySuccessful, - UnknownURI, - DownloadSkipped, - LoadInProgress, - ObjectCount, - ObjectList, -) - - -logger = logging.getLogger(__name__) - - -class PythonServiceHandler: - """ - A wrapper around PythonService object that receives requests and calls the - corresponding methods. - """ - - def __init__(self, ps): - self.ps = ps - - def manage_request(self, msg): - try: - logger.debug(f"Received request {type(msg).__name__}") - if isinstance(msg, LoadObject): - response = self.ps.load_object(*msg) - elif isinstance(msg, DeleteObjects): - response = self.ps.delete_objects(msg.uris) - elif isinstance(msg, FlushObjects): - response = self.ps.flush_objects() - elif isinstance(msg, CountObjects): - response = self.ps.count_objects() - elif isinstance(msg, ListObjects): - response = self.ps.list_objects() - else: - response = UnknownMessage(msg) - - logger.debug(f"Returning response {response}") - return response - except Exception as e: - logger.exception(e) - msg = e - if hasattr(e, "message"): - msg = e.message - logger.error(f"Error processing request: {msg}") - return UnknownMessage(msg) - - -class PythonService: - """ - This class is a simple wrapper maintaining loaded query objects from - the current TabPy instance. 
`query_objects` is a dictionary that - maps query object URI to query objects - - The query_objects schema is as follow: - - {'version': , - 'last_error':, - 'endpoint_obj':, - 'type':, - 'status':} - - """ - - def __init__(self, query_objects=None): - - self.EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=1) - self.query_objects = query_objects or {} - - def _load_object( - self, object_uri, object_url, object_version, is_update, object_type - ): - try: - logger.info( - f"Loading object:, URI={object_uri}, " - f"URL={object_url}, version={object_version}, " - f"is_updated={is_update}" - ) - if object_type == "model": - po = QueryObject.load(object_url) - elif object_type == "alias": - po = object_url - else: - raise RuntimeError(f"Unknown object type: {object_type}") - - self.query_objects[object_uri] = { - "version": object_version, - "type": object_type, - "endpoint_obj": po, - "status": "LoadSuccessful", - "last_error": None, - } - except Exception as e: - logger.exception(e) - logger.error( - f"Unable to load QueryObject: path={object_url}, " f"error={str(e)}" - ) - - self.query_objects[object_uri] = { - "version": object_version, - "type": object_type, - "endpoint_obj": None, - "status": "LoadFailed", - "last_error": f"Load failed: {str(e)}", - } - - def load_object( - self, object_uri, object_url, object_version, is_update, object_type - ): - try: - obj_info = self.query_objects.get(object_uri) - if ( - obj_info - and obj_info["endpoint_obj"] - and (obj_info["version"] >= object_version) - ): - logger.info("Received load message for object already loaded") - - return DownloadSkipped( - object_uri, - obj_info["version"], - "Object with greater " "or equal version already loaded", - ) - else: - if object_uri not in self.query_objects: - self.query_objects[object_uri] = { - "version": object_version, - "type": object_type, - "endpoint_obj": None, - "status": "LoadInProgress", - "last_error": None, - } - else: - 
self.query_objects[object_uri]["status"] = "LoadInProgress" - - self.EXECUTOR.submit( - self._load_object, - object_uri, - object_url, - object_version, - is_update, - object_type, - ) - - return LoadInProgress( - object_uri, object_url, object_version, is_update, object_type - ) - except Exception as e: - logger.exception(e) - logger.error( - f"Unable to load QueryObject: path={object_url}, " f"error={str(e)}" - ) - - self.query_objects[object_uri] = { - "version": object_version, - "type": object_type, - "endpoint_obj": None, - "status": "LoadFailed", - "last_error": str(e), - } - - return LoadFailed(object_uri, object_version, str(e)) - - def delete_objects(self, object_uris): - """Delete one or more objects from the query_objects map""" - if isinstance(object_uris, list): - deleted = [] - for uri in object_uris: - deleted.extend(self.delete_objects(uri).uris) - return ObjectsDeleted(deleted) - elif isinstance(object_uris, str): - deleted_obj = self.query_objects.pop(object_uris, None) - if deleted_obj: - return ObjectsDeleted([object_uris]) - else: - logger.warning( - f"Received message to delete query object " - f"that doesn't exist: " - f"object_uris={object_uris}" - ) - return ObjectsDeleted([]) - else: - logger.error( - f"Unexpected input to delete objects: input={object_uris}, " - f'info="Input should be list or str. 
' - f'Type: {type(object_uris)}"' - ) - return ObjectsDeleted([]) - - def flush_objects(self): - """Flush objects from the query_objects map""" - logger.debug("Flushing query objects") - n = len(self.query_objects) - self.query_objects.clear() - return ObjectsFlushed(n, 0) - - def count_objects(self): - """Count the number of Loaded QueryObjects stored in memory""" - count = 0 - for uri, po in self.query_objects.items(): - if po["endpoint_obj"] is not None: - count += 1 - return ObjectCount(count) - - def list_objects(self): - """List the objects as (URI, version) pairs""" - - objects = {} - for (uri, obj_info) in self.query_objects.items(): - objects[uri] = { - "version": obj_info["version"], - "type": obj_info["type"], - "status": obj_info["status"], - "reason": obj_info["last_error"], - } - - return ObjectList(objects) - - def query(self, object_uri, params, uid): - """Execute a QueryObject query""" - logger.debug(f"Querying Python service {object_uri}...") - try: - if not isinstance(params, dict) and not isinstance(params, list): - return QueryFailed( - uri=object_uri, - error=( - "Query parameter needs to be a dictionary or a list" - f". 
Given value is of type {type(params)}" - ), - ) - - obj_info = self.query_objects.get(object_uri) - logger.debug(f"Found object {obj_info}") - if obj_info: - pred_obj = obj_info["endpoint_obj"] - version = obj_info["version"] - - if not pred_obj: - return QueryFailed( - uri=object_uri, - error=( - "There is no query object associated to the " - f"endpoint: {object_uri}" - ), - ) - - logger.debug(f"Querying endpoint with params ({params})...") - if isinstance(params, dict): - result = pred_obj.query(**params) - else: - result = pred_obj.query(*params) - - return QuerySuccessful(object_uri, version, result) - else: - return UnknownURI(object_uri) - except Exception as e: - logger.exception(e) - err_msg = format_exception(e, "/query") - logger.error(err_msg) - return QueryFailed(uri=object_uri, error=err_msg) diff --git a/tabpy/tabpy_tools/__init__.py b/tabpy/tabpy_tools/__init__.py deleted file mode 100755 index e69de29b..00000000 diff --git a/tabpy/tabpy_tools/client.py b/tabpy/tabpy_tools/client.py deleted file mode 100644 index 96c2c267..00000000 --- a/tabpy/tabpy_tools/client.py +++ /dev/null @@ -1,395 +0,0 @@ -import copy -from re import compile -import time -import requests - -from .rest import RequestsNetworkWrapper, ServiceClient - -from .rest_client import RESTServiceClient, Endpoint - -from .custom_query_object import CustomQueryObject -import os -import logging - -logger = logging.getLogger(__name__) - -_name_checker = compile(r"^[\w -]+$") - - -def _check_endpoint_type(name): - if not isinstance(name, str): - raise TypeError("Endpoint name must be a string") - - if name == "": - raise ValueError("Endpoint name cannot be empty") - - -def _check_hostname(name): - _check_endpoint_type(name) - hostname_checker = compile(r"^^http(s)?://[\w.-]+(/)?(:\d+)?(/)?$") - - if not hostname_checker.match(name): - raise ValueError( - f"endpoint name {name} should be in http(s)://" - "[:] and hostname may consist only of: " - "a-z, A-Z, 0-9, underscore and hyphens." 
- ) - - -def _check_endpoint_name(name): - """Checks that the endpoint name is valid by comparing it with an RE and - checking that it is not reserved.""" - _check_endpoint_type(name) - - if not _name_checker.match(name): - raise ValueError( - f"endpoint name {name} can only contain: a-z, A-Z, 0-9," - " underscore, hyphens and spaces." - ) - - -class Client: - def __init__(self, endpoint, query_timeout=1000): - """ - Connects to a running server. - - The class constructor takes a server address which is then used to - connect for all subsequent member APIs. - - Parameters - ---------- - endpoint : str, optional - The server URL. - - query_timeout : float, optional - The timeout for query operations. - """ - _check_hostname(endpoint) - - self._endpoint = endpoint - - session = requests.session() - session.verify = False - requests.packages.urllib3.disable_warnings() - - # Setup the communications layer. - network_wrapper = RequestsNetworkWrapper(session) - service_client = ServiceClient(self._endpoint, network_wrapper) - - self._service = RESTServiceClient(service_client) - if type(query_timeout) in (int, float) and query_timeout > 0: - self._service.query_timeout = query_timeout - else: - self._service.query_timeout = 0.0 - - def __repr__(self): - return ( - "<" - + self.__class__.__name__ - + " object at " - + hex(id(self)) - + " connected to " - + repr(self._endpoint) - + ">" - ) - - def get_status(self): - """ - Gets the status of the deployed endpoints. - - Returns - ------- - dict - Keys are endpoints and values are dicts describing the state of - the endpoint. - - Examples - -------- - .. 
sourcecode:: python - { - u'foo': { - u'status': u'LoadFailed', - u'last_error': u'error mesasge', - u'version': 1, - u'type': u'model', - }, - } - """ - return self._service.get_status() - - # - # Query - # - - @property - def query_timeout(self): - """The timeout for queries in milliseconds.""" - return self._service.query_timeout - - @query_timeout.setter - def query_timeout(self, value): - if type(value) in (int, float) and value > 0: - self._service.query_timeout = value - - def query(self, name, *args, **kwargs): - """Query an endpoint. - - Parameters - ---------- - name : str - The name of the endpoint. - - *args : list of anything - Ordered parameters to the endpoint. - - **kwargs : dict of anything - Named parameters to the endpoint. - - Returns - ------- - dict - Keys are: - model: the name of the endpoint - version: the version used. - response: the response to the query. - uuid : a unique id for the request. - """ - return self._service.query(name, *args, **kwargs) - - # - # Endpoints - # - - def get_endpoints(self, type=None): - """Returns all deployed endpoints. - - Examples - -------- - .. sourcecode:: python - {"clustering": - {"description": "", - "docstring": "-- no docstring found in query function --", - "creation_time": 1469511182, - "version": 1, - "dependencies": [], - "last_modified_time": 1469511182, - "type": "model", - "target": null}, - "add": { - "description": "", - "docstring": "-- no docstring found in query function --", - "creation_time": 1469505967, - "version": 1, - "dependencies": [], - "last_modified_time": 1469505967, - "type": "model", - "target": null} - } - """ - return self._service.get_endpoints(type) - - def _get_endpoint_upload_destination(self): - """Returns the endpoint upload destination.""" - return self._service.get_endpoint_upload_destination()["path"] - - def deploy(self, name, obj, description="", schema=None, override=False): - """Deploys a Python function as an endpoint in the server. 
- - Parameters - ---------- - name : str - A unique identifier for the endpoint. - - obj : function - Refers to a user-defined function with any signature. However both - input and output of the function need to be JSON serializable. - - description : str, optional - The description for the endpoint. This string will be returned by - the ``endpoints`` API. - - schema : dict, optional - The schema of the function, containing information about input and - output parameters, and respective examples. Providing a schema for - a deployed function lets other users of the service discover how to - use it. Refer to schema.generate_schema for more information on - how to generate the schema. - - override : bool - Whether to override (update) an existing endpoint. If False and - there is already an endpoint with that name, it will raise a - RuntimeError. If True and there is already an endpoint with that - name, it will deploy a new version on top of it. - - See Also - -------- - remove, get_endpoints - """ - endpoint = self.get_endpoints().get(name) - if endpoint: - if not override: - raise RuntimeError( - f"An endpoint with that name ({name}) already" - ' exists. Use "override = True" to force update ' - "an existing endpoint." - ) - - version = endpoint.version + 1 - else: - version = 1 - - obj = self._gen_endpoint(name, obj, description, version, schema) - - self._upload_endpoint(obj) - - if version == 1: - self._service.add_endpoint(Endpoint(**obj)) - else: - self._service.set_endpoint(Endpoint(**obj)) - - self._wait_for_endpoint_deployment(obj["name"], obj["version"]) - - def remove(self, name): - '''Removes an endpoint dict. - - Parameters - ---------- - name : str - Endpoint name to remove''' - self._service.remove_endpoint(name) - - def _gen_endpoint(self, name, obj, description, version=1, schema=None): - """Generates an endpoint dict. - - Parameters - ---------- - name : str - Endpoint name to add or update - - obj : func - Object that backs the endpoint. 
See add() for a complete - description. - - description : str - Description of the endpoint - - version : int - The version. Defaults to 1. - - Returns - ------- - dict - Keys: - name : str - The name provided. - - version : int - The version provided. - - description : str - The provided description. - - type : str - The type of the endpoint. - - endpoint_obj : object - The wrapper around the obj provided that can be used to - generate the code and dependencies for the endpoint. - - Raises - ------ - TypeError - When obj is not one of the expected types. - """ - # check for invalid PO names - _check_endpoint_name(name) - - if description is None: - if isinstance(obj.__doc__, str): - # extract doc string - description = obj.__doc__.strip() or "" - else: - description = "" - - endpoint_object = CustomQueryObject(query=obj, description=description,) - - return { - "name": name, - "version": version, - "description": description, - "type": "model", - "endpoint_obj": endpoint_object, - "dependencies": endpoint_object.get_dependencies(), - "methods": endpoint_object.get_methods(), - "required_files": [], - "required_packages": [], - "schema": copy.copy(schema), - } - - def _upload_endpoint(self, obj): - """Sends the endpoint across the wire.""" - endpoint_obj = obj["endpoint_obj"] - - dest_path = self._get_endpoint_upload_destination() - - # Upload the endpoint - obj["src_path"] = os.path.join( - dest_path, "endpoints", obj["name"], str(obj["version"]) - ) - - endpoint_obj.save(obj["src_path"]) - - def _wait_for_endpoint_deployment( - self, endpoint_name, version=1, interval=1.0, - ): - """ - Waits for the endpoint to be deployed by calling get_status() and - checking the versions deployed of the endpoint against the expected - version. If all the versions are equal to or greater than the version - expected, then it will return. Uses time.sleep(). 
- """ - logger.info( - f"Waiting for endpoint {endpoint_name} to deploy to " f"version {version}" - ) - start = time.time() - while True: - ep_status = self.get_status() - try: - ep = ep_status[endpoint_name] - except KeyError: - logger.info( - f"Endpoint {endpoint_name} doesn't " "exist in endpoints yet" - ) - else: - logger.info(f"ep={ep}") - - if ep["status"] == "LoadFailed": - raise RuntimeError(f'LoadFailed: {ep["last_error"]}') - - elif ep["status"] == "LoadSuccessful": - if ep["version"] >= version: - logger.info("LoadSuccessful") - break - else: - logger.info("LoadSuccessful but wrong version") - - if time.time() - start > 10: - raise RuntimeError("Waited more then 10s for deployment") - - logger.info(f"Sleeping {interval}...") - time.sleep(interval) - - def set_credentials(self, username, password): - """ - Set credentials for all the TabPy client-server communication - where client is tabpy-tools and server is tabpy-server. - - Parameters - ---------- - username : str - User name (login). Username is case insensitive. - - password : str - Password in plain text. - """ - self._service.set_credentials(username, password) diff --git a/tabpy/tabpy_tools/custom_query_object.py b/tabpy/tabpy_tools/custom_query_object.py deleted file mode 100755 index 18a149b8..00000000 --- a/tabpy/tabpy_tools/custom_query_object.py +++ /dev/null @@ -1,83 +0,0 @@ -import logging -from .query_object import QueryObject as _QueryObject - - -logger = logging.getLogger(__name__) - - -class CustomQueryObject(_QueryObject): - def __init__(self, query, description=""): - """Create a new CustomQueryObject. - - Parameters - ----------- - - query : function - Function that defines a custom query method. The query can have any - signature, but input and output of the query needs to be JSON - serializable. 
- - description : str - The description of the custom query object - - """ - super().__init__(description) - - self.custom_query = query - - def query(self, *args, **kwargs): - """Query the custom defined query method using the given input. - - Parameters - ---------- - args : list - positional arguments to the query - - kwargs : dict - keyword arguments to the query - - Returns - ------- - out: object. - The results depends on the implementation of the query method. - Typically the return value will be whatever that function returns. - - See Also - -------- - QueryObject - """ - # include the dependent files in sys path so that the query can run - # correctly - - try: - logger.debug( - "Running custom query with arguments " f"({args}, {kwargs})..." - ) - ret = self.custom_query(*args, **kwargs) - except Exception as e: - logger.exception( - "Exception hit when running custom query, error: " f"{str(e)}" - ) - raise - - logger.debug(f"Received response {ret}") - try: - return self._make_serializable(ret) - except Exception as e: - logger.exception( - "Cannot properly serialize custom query result, " f"error: {str(e)}" - ) - raise - - def get_doc_string(self): - """Get doc string from customized query""" - if self.custom_query.__doc__ is not None: - return self.custom_query.__doc__ - else: - return "-- no docstring found in query function --" - - def get_methods(self): - return [self.get_query_method()] - - def get_query_method(self): - return {"method": "query"} diff --git a/tabpy/tabpy_tools/query_object.py b/tabpy/tabpy_tools/query_object.py deleted file mode 100755 index 5ccbc109..00000000 --- a/tabpy/tabpy_tools/query_object.py +++ /dev/null @@ -1,108 +0,0 @@ -import abc -import logging -import os -import json -import shutil - -import cloudpickle as _cloudpickle - - -logger = logging.getLogger(__name__) - - -class QueryObject(abc.ABC): - """ - Derived class needs to implement the following interface: - * query() -- given input, return query result - * 
get_doc_string() -- returns documentation for the Query Object - """ - - def __init__(self, description=""): - self.description = description - - def get_dependencies(self): - """All endpoints this endpoint depends on""" - return [] - - @abc.abstractmethod - def query(self, input): - """execute query on the provided input""" - pass - - @abc.abstractmethod - def get_doc_string(self): - """Returns documentation for the query object - - By default, this method returns the docstring for 'query' method - Derived class may overwrite this method to dynamically create docstring - """ - pass - - def save(self, path): - """ Save query object to the given local path - - Parameters - ---------- - path : str - The location to save the query object to - """ - if os.path.exists(path): - logger.warning( - f'Overwriting existing file "{path}" when saving query object' - ) - rm_fn = os.remove if os.path.isfile(path) else shutil.rmtree - rm_fn(path) - self._save_local(path) - - def _save_local(self, path): - """Save current query object to local path - """ - try: - os.makedirs(path) - except OSError as e: - import errno - - if e.errno == errno.EEXIST and os.path.isdir(path): - pass - else: - raise - - with open(os.path.join(path, "pickle_archive"), "wb") as f: - _cloudpickle.dump(self, f) - - @classmethod - def load(cls, path): - """ Load query object from given path - """ - new_po = None - new_po = cls._load_local(path) - - logger.info(f'Loaded query object "{type(new_po).__name__}" successfully') - - return new_po - - @classmethod - def _load_local(cls, path): - path = os.path.abspath(os.path.expanduser(path)) - with open(os.path.join(path, "pickle_archive"), "rb") as f: - return _cloudpickle.load(f) - - @classmethod - def _make_serializable(cls, result): - """Convert a result from object query to python data structure that can - easily serialize over network - """ - try: - json.dumps(result) - except TypeError: - raise TypeError( - "Result from object query is not json 
serializable: " f"{result}" - ) - - return result - - # Returns an array of dictionary that contains the methods and their - # corresponding schema information. - @abc.abstractmethod - def get_methods(self): - return None diff --git a/tabpy/tabpy_tools/rest.py b/tabpy/tabpy_tools/rest.py deleted file mode 100755 index f200c250..00000000 --- a/tabpy/tabpy_tools/rest.py +++ /dev/null @@ -1,423 +0,0 @@ -import abc -from collections.abc import MutableMapping -import logging -import requests -from requests.auth import HTTPBasicAuth -from re import compile -import json as json - - -logger = logging.getLogger(__name__) - - -class ResponseError(Exception): - """Raised when we get an unexpected response.""" - - def __init__(self, response): - super().__init__("Unexpected server response") - self.response = response - self.status_code = response.status_code - - try: - r = response.json() - self.info = r["info"] - self.message = response.json()["message"] - except (json.JSONDecodeError, KeyError): - self.info = None - self.message = response.text - - def __str__(self): - return f"({self.status_code}) " f"{self.message} " f"{self.info}" - - -class RequestsNetworkWrapper: - """The NetworkWrapper wraps the underlying network connection to simplify - the interface a bit. This can be replaced with something that can be built - on some other type of network connection, such as PyCURL. - - This version requires you to instantiate a requests session object to your - liking. It will create a generic session for you if you don't specify it, - which you can modify later. - - For authentication, use:: - - session.auth = (username, password) - """ - - def __init__(self, session=None): - # Set .auth as appropriate. - if session is None: - session = requests.session() - - self.session = session - self.auth = None - - @staticmethod - def raise_error(response): - logger.error( - f"Error with server response. 
code={response.status_code}; " - f"text={response.text}" - ) - - raise ResponseError(response) - - @staticmethod - def _remove_nones(data): - if isinstance(data, dict): - for k in [k for k, v in data.items() if v is None]: - del data[k] - - def _encode_request(self, data): - self._remove_nones(data) - - if data is not None: - return json.dumps(data) - else: - return None - - def GET(self, url, data, timeout=None): - """Issues a GET request to the URL with the data specified. Returns an - object that is parsed from the response JSON.""" - self._remove_nones(data) - - logger.info(f"GET {url} with {data}") - - response = self.session.get(url, params=data, timeout=timeout, auth=self.auth) - if response.status_code != 200: - self.raise_error(response) - logger.info(f"response={response.text}") - - if response.text == "": - return dict() - else: - return response.json() - - def POST(self, url, data, timeout=None): - """Issues a POST request to the URL with the data specified. Returns an - object that is parsed from the response JSON.""" - data = self._encode_request(data) - - logger.info(f"POST {url} with {data}") - response = self.session.post( - url, - data=data, - headers={"content-type": "application/json"}, - timeout=timeout, - auth=self.auth, - ) - - if response.status_code not in (200, 201): - self.raise_error(response) - - return response.json() - - def PUT(self, url, data, timeout=None): - """Issues a PUT request to the URL with the data specified. Returns an - object that is parsed from the response JSON.""" - data = self._encode_request(data) - - logger.info(f"PUT {url} with {data}") - - response = self.session.put( - url, - data=data, - headers={"content-type": "application/json"}, - timeout=timeout, - auth=self.auth, - ) - if response.status_code != 200: - self.raise_error(response) - - return response.json() - - def DELETE(self, url, data, timeout=None): - """ - Issues a DELETE request to the URL with the data specified. 
Returns an - object that is parsed from the response JSON. - """ - if data is not None: - data = json.dumps(data) - - logger.info(f"DELETE {url} with {data}") - - response = self.session.delete(url, data=data, timeout=timeout, auth=self.auth) - - if response.status_code <= 499 and response.status_code >= 400: - raise RuntimeError(response.text) - - if response.status_code not in (200, 201, 204): - raise RuntimeError( - f"Error with server response code: {response.status_code}" - ) - - def set_credentials(self, username, password): - """ - Set credentials for all the TabPy client-server communication - where client is tabpy-tools and server is tabpy-server. - - Parameters - ---------- - username : str - User name (login). Username is case insensitive. - - password : str - Password in plain text. - """ - logger.info(f"Setting credentials (username: {username})") - self.auth = HTTPBasicAuth(username, password) - - -class ServiceClient: - """ - A generic service client. - - This will take an endpoint URL and a network_wrapper. You can use the - RequestsNetworkWrapper if you want to use the requests module. The - endpoint URL is prepended to all the requests and forwarded to the network - wrapper. 
- """ - - def __init__(self, endpoint, network_wrapper=None): - if network_wrapper is None: - network_wrapper = RequestsNetworkWrapper(session=requests.session()) - - self.network_wrapper = network_wrapper - - pattern = compile(".*(:[0-9]+)$") - if not endpoint.endswith("/") and not pattern.match(endpoint): - logger.warning(f"endpoint {endpoint} does not end with '/': appending.") - endpoint = endpoint + "/" - - self.endpoint = endpoint - - def GET(self, url, data=None, timeout=None): - """Prepends self.endpoint to the url and issues a GET request.""" - return self.network_wrapper.GET(self.endpoint + url, data, timeout) - - def POST(self, url, data=None, timeout=None): - """Prepends self.endpoint to the url and issues a POST request.""" - return self.network_wrapper.POST(self.endpoint + url, data, timeout) - - def PUT(self, url, data=None, timeout=None): - """Prepends self.endpoint to the url and issues a PUT request.""" - return self.network_wrapper.PUT(self.endpoint + url, data, timeout) - - def DELETE(self, url, data=None, timeout=None): - """Prepends self.endpoint to the url and issues a DELETE request.""" - self.network_wrapper.DELETE(self.endpoint + url, data, timeout) - - def set_credentials(self, username, password): - """ - Set credentials for all the TabPy client-server communication - where client is tabpy-tools and server is tabpy-server. - - Parameters - ---------- - username : str - User name (login). Username is case insensitive. - - password : str - Password in plain text. 
- """ - self.network_wrapper.set_credentials(username, password) - - -class RESTProperty: - """A descriptor that will control the type of value stored.""" - - def __init__(self, type, from_json=lambda x: x, to_json=lambda x: x, doc=None): - self.__doc__ = doc - self.type = type - self.from_json = from_json - self.to_json = to_json - - def __get__(self, instance, _): - if instance: - try: - return getattr(instance, self.name) - except AttributeError: - raise AttributeError(f"{self.name} has not been set yet.") - else: - return self - - def __set__(self, instance, value): - if value is not None and not isinstance(value, self.type): - value = self.type(value) - - setattr(instance, self.name, value) - - def __delete__(self, instance): - delattr(instance, self.name) - - -class _RESTMetaclass(abc.ABCMeta): - """The metaclass for RESTObjects. - - This will look into the attributes for the class. If they are a - RESTProperty, then it will add it to the __rest__ set and give it its - name. - - If the bases have __rest__, then it will add them to the __rest__ set as - well. - """ - - def __init__(self, name, bases, dict): - super().__init__(name, bases, dict) - - self.__rest__ = set() - for base in bases: - self.__rest__.update(getattr(base, "__rest__", set())) - - for k, v in dict.items(): - if isinstance(v, RESTProperty): - v.__dict__["name"] = "_" + k - self.__rest__.add(k) - - -class RESTObject(MutableMapping, metaclass=_RESTMetaclass): - """A base class that has methods generally useful for interacting with - REST objects. The attributes are accessible either as dict keys or as - attributes. The object also behaves like a dict, even replicating the - repr() functionality. - - Attributes - ---------- - - __rest__ : set of str - A set of all the rest attribute names. This is generated automatically - and should include all of the base classes' __rest__ as well as any - addition RESTProperty. 
- - """ - - """ __metaclass__ = _RESTMetaclass""" - - def __init__(self, **kwargs): - """Creates a new instance of the RESTObject. - - Parameters - ---------- - - The parameters depend on __rest__. Each item in __rest__ is searched - for. If found, it is assigned to the instance. Additional parameters - are ignored. - - """ - logger.info(f"Initializing {self.__class__.__name__} from {kwargs}") - for attr in self.__rest__: - if attr in kwargs: - setattr(self, attr, kwargs.pop(attr)) - - def __repr__(self): - return ( - "{" + ", ".join([repr(k) + ": " + repr(v) for k, v in self.items()]) + "}" - ) - - @classmethod - def from_json(cls, data): - """Returns a new class object with data populated from json.loads().""" - attrs = {} - for attr in cls.__rest__: - try: - value = data[attr] - except KeyError: - pass - else: - prop = cls.__dict__[attr] - attrs[attr] = prop.from_json(value) - return cls(**attrs) - - def to_json(self): - """Returns a dict representing this object. This dict will be sent to - json.dumps(). - - The keys are the items in __rest__ and the values are the current - values. If missing, it is not included. 
- """ - result = {} - for attr in self.__rest__: - prop = getattr(self.__class__, attr) - try: - result[attr] = prop.to_json(getattr(self, attr)) - except AttributeError: - pass - - return result - - def __eq__(self, other): - return isinstance(self, type(other)) and all( - (getattr(self, a) == getattr(other, a) for a in self.__rest__) - ) - - def __len__(self): - return len([a for a in self.__rest__ if hasattr(self, "_" + a)]) - - def __iter__(self): - return iter([a for a in self.__rest__ if hasattr(self, "_" + a)]) - - def __getitem__(self, item): - if item not in self.__rest__: - raise KeyError(item) - try: - return getattr(self, item) - except AttributeError: - raise KeyError(item) - - def __setitem__(self, item, value): - if item not in self.__rest__: - raise KeyError(item) - setattr(self, item, value) - - def __delitem__(self, item): - if item not in self.__rest__: - raise KeyError(item) - try: - delattr(self, "_" + item) - except AttributeError: - raise KeyError(item) - - def __contains__(self, item): - return item in self.__rest__ - - -def enum(*values, **kwargs): - """Generates an enum function that only accepts particular values. Other - values will raise a ValueError. - - Parameters - ---------- - - values : list - These are the acceptable values. - - type : type - The acceptable types of values. Values will be converted before being - checked against the allowed values. If not specified, no conversion - will be performed. - - Example - ------- - - >>> my_enum = enum(1, 2, 3, 4, 5, type=int) - >>> a = my_enum(1) - >>> b = my_enum(2) - >>> c = mu_enum(6) # Raises ValueError - - """ - if len(values) < 1: - raise ValueError("At least one value is required.") - enum_type = kwargs.pop("type", str) - if kwargs: - raise TypeError(f'Unexpected parameters: {", ".join(kwargs.keys())}') - - def __new__(cls, value): - if value not in cls.values: - raise ValueError( - f"{value} is an unexpected value. 
" f"Expected one of {cls.values}" - ) - - return super(enum, cls).__new__(cls, value) - - enum = type("Enum", (enum_type,), {"values": values, "__new__": __new__}) - - return enum diff --git a/tabpy/tabpy_tools/rest_client.py b/tabpy/tabpy_tools/rest_client.py deleted file mode 100755 index eb0ef211..00000000 --- a/tabpy/tabpy_tools/rest_client.py +++ /dev/null @@ -1,252 +0,0 @@ -from .rest import RESTObject, RESTProperty -from datetime import datetime - - -def from_epoch(value): - if isinstance(value, datetime): - return value - else: - return datetime.utcfromtimestamp(value) - - -def to_epoch(value): - return (value - datetime(1970, 1, 1)).total_seconds() - - -class Endpoint(RESTObject): - """Represents an endpoint. - - Note that not every attribute is returned as part of the GET. - - Attributes - ---------- - - name : str - The name of the endpoint. Valid names include ``[a-zA-Z0-9_\\- ]+`` - type : str - The type of endpoint. The types include "alias", "model". - version : int - The version of this endpoint. Initial versions have version on 1. New - versions increment this by 1. - description : str - A human-readable description of the endpoint. - dependencies: list - A list of endpoints that this endpoint depends on. - methods : list - ??? 
- """ - - name = RESTProperty(str) - type = RESTProperty(str) - version = RESTProperty(int) - description = RESTProperty(str) - dependencies = RESTProperty(list) - methods = RESTProperty(list) - creation_time = RESTProperty(datetime, from_epoch, to_epoch) - last_modified_time = RESTProperty(datetime, from_epoch, to_epoch) - evaluator = RESTProperty(str) - schema_version = RESTProperty(int) - schema = RESTProperty(str) - - def __new__(cls, **kwargs): - """Dispatch to the appropriate class.""" - cls = {"alias": AliasEndpoint, "model": ModelEndpoint}[kwargs["type"]] - - """return object.__new__(cls, **kwargs)""" - """ modified for Python 3""" - return object.__new__(cls) - - def __eq__(self, other): - return ( - self.name == other.name - and self.type == other.type - and self.version == other.version - and self.description == other.description - and self.dependencies == other.dependencies - and self.methods == other.methods - and self.evaluator == other.evaluator - and self.schema_version == other.schema_version - and self.schema == other.schema - ) - - -class ModelEndpoint(Endpoint): - """Represents a model endpoint. - - src_path : str - - The local file path to the source of this object. - - required_files : str - - The local file path to the directory containing the - required files. - - required_packages : str - - The local file path to the directory containing the - required packages. - - """ - - src_path = RESTProperty(str) - required_files = RESTProperty(list) - required_packages = RESTProperty(list) - required_packages_dst_path = RESTProperty(str) - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.type = "model" - - def __eq__(self, other): - return ( - super().__eq__(other) - and self.required_files == other.required_files - and self.required_packages == other.required_packages - ) - - -class AliasEndpoint(Endpoint): - """Represents an alias Endpoint. - - target : str - - The endpoint that this is an alias for. 
- - """ - - target = RESTProperty(str) - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.type = "alias" - - -class RESTServiceClient: - """A thin client for the REST Service.""" - - def __init__(self, service_client): - self.service_client = service_client - self.query_timeout = None - - def get_info(self): - """Returns the /info""" - return self.service_client.GET("info") - - def query(self, name, *args, **kwargs): - """Performs a query. Either specify *args or **kwargs, not both. - Respects query_timeout.""" - if args and kwargs: - raise ValueError( - "Mixing of keyword arguments and positional arguments when " - "querying an endpoint is not supported." - ) - return self.service_client.POST( - "query/" + name, data={"data": args or kwargs}, timeout=self.query_timeout - ) - - def get_endpoint_upload_destination(self): - """Returns a dict representing where endpoint data should be uploaded. - - Returns - ------- - dict - Keys include: - * path: a local file path. - - Note: In the future, other paths and parameters may be supported. - - Note: At this time, the response should not change over time. - """ - return self.service_client.GET("configurations/endpoint_upload_destination") - - def get_endpoints(self, type=None): - """Returns endpoints from the management API. - - Parameters - ---------- - - type : str - The type of endpoint to return. None will include all endpoints. - Other options are 'model' and 'alias'. - """ - result = {} - for name, attrs in self.service_client.GET("endpoints", {"type": type}).items(): - endpoint = Endpoint.from_json(attrs) - endpoint.name = name - result[name] = endpoint - return result - - def get_endpoint(self, endpoint_name): - """Returns an endpoints from the management API given its name. - - Parameters - ---------- - - endpoint_name : str - - The name of the endpoint. 
- """ - ((name, attrs),) = self.service_client.GET("endpoints/" + endpoint_name).items() - endpoint = Endpoint.from_json(attrs) - endpoint.name = name - return endpoint - - def add_endpoint(self, endpoint): - """Adds an endpoint through the management API. - - Parameters - ---------- - - endpoint : Endpoint - """ - return self.service_client.POST("endpoints", endpoint.to_json()) - - def set_endpoint(self, endpoint): - """Updates an endpoint through the management API. - - Parameters - ---------- - - endpoint : Endpoint - - The endpoint to update. - """ - return self.service_client.PUT("endpoints/" + endpoint.name, endpoint.to_json()) - - def remove_endpoint(self, endpoint_name): - """Deletes an endpoint through the management API. - - Parameters - ---------- - - endpoint_name : str - - The endpoint to delete. - """ - self.service_client.DELETE("endpoints/" + endpoint_name) - - def get_status(self): - """Returns the status of the server. - - Returns - ------- - - dict - """ - return self.service_client.GET("status") - - def set_credentials(self, username, password): - """ - Set credentials for all the TabPy client-server communication - where client is tabpy-tools and server is tabpy-server. - - Parameters - ---------- - username : str - User name (login). Username is case insensitive. - - password : str - Password in plain text. - """ - self.service_client.set_credentials(username, password) diff --git a/tabpy/tabpy_tools/schema.py b/tabpy/tabpy_tools/schema.py deleted file mode 100755 index ba36bae2..00000000 --- a/tabpy/tabpy_tools/schema.py +++ /dev/null @@ -1,108 +0,0 @@ -import logging -import genson -import jsonschema - - -logger = logging.getLogger(__name__) - - -def _generate_schema_from_example_and_description(input, description): - """ - With an example input, a schema is automatically generated that conforms - to the example in json-schema.org. The description given by the users - is then added to the schema. 
- """ - s = genson.SchemaBuilder(None) - s.add_object(input) - input_schema = s.to_schema() - - if description is not None: - if "properties" in input_schema: - # Case for input = {'x':1}, input_description='not a dict' - if not isinstance(description, dict): - msg = f"{input} and {description} do not match" - logger.error(msg) - raise Exception(msg) - - for key in description: - # Case for input = {'x':1}, - # input_description={'x':'x value', 'y':'y value'} - if key not in input_schema["properties"]: - msg = f"{key} not found in {input}" - logger.error(msg) - raise Exception(msg) - else: - input_schema["properties"][key]["description"] = description[key] - else: - if isinstance(description, dict): - raise Exception(f"{input} and {description} do not match") - else: - input_schema["description"] = description - - try: - # This should not fail unless there are bugs with either genson or - # jsonschema. - jsonschema.validate(input, input_schema) - except Exception as e: - logger.error(f"Internal error validating schema: {str(e)}") - raise - - return input_schema - - -def generate_schema(input, output, input_description=None, output_description=None): - """ - Generate schema from a given sample input and output. - A generated schema can be passed to a server together with a function to - annotate it with information about input and output parameters, and - examples thereof. The schema needs to follow the conventions of JSON Schema - (see json-schema.org). - - Parameters - ----------- - input : any python type | dict - output: any python type | dict - input_description : str | dict, optional - output_description : str | dict, optional - - References - ----------- - - `Json Schema ` - - Examples - ---------- - .. sourcecode:: python - For just one input parameter, state the example directly. 
- >>> from tabpy_tools.schema import generate_schema - >>> schema = generate_schema( - input=5, - output=25, - input_description='input value', - output_description='the squared value of input') - >>> schema - {'sample': 5, - 'input': {'type': 'integer', 'description': 'input value'}, - 'output': {'type': 'integer', 'description': 'the squared value of input'}} - For two or more input parameters, specify them using a dictionary. - >>> import graphlab - >>> schema = generate_schema( - input={'x': 3, 'y': 2}, - output=6, - input_description={'x': 'value of x', - 'y': 'value of y'}, - output_description='x times y') - >>> schema - {'sample': {'y': 2, 'x': 3}, - 'input': {'required': ['x', 'y'], - 'type': 'object', - 'properties': {'y': {'type': 'integer', 'description': 'value of y'}, - 'x': {'type': 'integer', 'description': 'value of x'}}}, - 'output': {'type': 'integer', 'description': 'x times y'}} - """ # noqa: E501 - input_schema = _generate_schema_from_example_and_description( - input, input_description - ) - output_schema = _generate_schema_from_example_and_description( - output, output_description - ) - return {"input": input_schema, "sample": input, "output": output_schema} diff --git a/tests/unit/server_tests/test_config.py b/tests/unit/server_tests/test_config.py index 84c6cd5f..c9aa205a 100644 --- a/tests/unit/server_tests/test_config.py +++ b/tests/unit/server_tests/test_config.py @@ -24,14 +24,12 @@ def test_config_file_does_not_exist(self): @patch("tabpy.tabpy_server.app.app.TabPyState") @patch("tabpy.tabpy_server.app.app._get_state_from_file") - @patch("tabpy.tabpy_server.app.app.PythonServiceHandler") @patch("tabpy.tabpy_server.app.app.os.path.exists", return_value=True) @patch("tabpy.tabpy_server.app.app.os") def test_no_config_file( self, mock_os, mock_path_exists, - mock_psws, mock_management_util, mock_tabpy_state, ): @@ -46,7 +44,6 @@ def test_no_config_file( TabPyApp(None) - self.assertEqual(len(mock_psws.mock_calls), 1) 
self.assertEqual(len(mock_tabpy_state.mock_calls), 1) self.assertEqual(len(mock_path_exists.mock_calls), 1) self.assertTrue(len(mock_management_util.mock_calls) > 0) @@ -54,14 +51,12 @@ def test_no_config_file( @patch("tabpy.tabpy_server.app.app.TabPyState") @patch("tabpy.tabpy_server.app.app._get_state_from_file") - @patch("tabpy.tabpy_server.app.app.PythonServiceHandler") @patch("tabpy.tabpy_server.app.app.os.path.exists", return_value=False) @patch("tabpy.tabpy_server.app.app.os") def test_no_state_ini_file_or_state_dir( self, mock_os, mock_path_exists, - mock_psws, mock_management_util, mock_tabpy_state, ): @@ -79,14 +74,12 @@ def tearDown(self): @patch("tabpy.tabpy_server.app.app.TabPyState") @patch("tabpy.tabpy_server.app.app._get_state_from_file") - @patch("tabpy.tabpy_server.app.app.PythonServiceHandler") @patch("tabpy.tabpy_server.app.app.os.path.exists", return_value=True) @patch("tabpy.tabpy_server.app.app.os") def test_config_file_present( self, mock_os, mock_path_exists, - mock_psws, mock_management_util, mock_tabpy_state, ): diff --git a/tests/unit/server_tests/test_endpoint_handler.py b/tests/unit/server_tests/test_endpoint_handler.py index a9393c42..d2d167c7 100755 --- a/tests/unit/server_tests/test_endpoint_handler.py +++ b/tests/unit/server_tests/test_endpoint_handler.py @@ -1,6 +1,5 @@ import base64 import os -import sys import tempfile from tabpy.tabpy_server.app.app import TabPyApp diff --git a/tests/unit/server_tests/test_evaluation_plane_handler.py b/tests/unit/server_tests/test_evaluation_plane_handler.py index 49b67dfb..7b63f5c4 100755 --- a/tests/unit/server_tests/test_evaluation_plane_handler.py +++ b/tests/unit/server_tests/test_evaluation_plane_handler.py @@ -2,7 +2,6 @@ import os import tempfile -from argparse import Namespace from tabpy.tabpy_server.app.app import TabPyApp from tabpy.tabpy_server.handlers.util import hash_password from tornado.testing import AsyncHTTPTestCase