From 37e90ce702434e7fd79bc46c0ab319a7b9c51694 Mon Sep 17 00:00:00 2001 From: vaimdev Date: Tue, 29 Nov 2022 21:10:05 +0800 Subject: [PATCH] Personal key (#420) * feat: Add org_name for logging in and corresponding Dataset upload with org_name intact * fix: org_name global value retrieving bug found by test script * refactor : clean up debugging code * unittest: Add test scenarios for login with org_name * fix: send org_name when create_file, privacy setting with mode_action * fix: Remove debugging code, fix backward incompatibility * fix: request json error * fix: only pass org_name in json if org_name is not None, which failed in old server * fix: arrow_uploader login handling for org_name * fix (security): Adjust default mode_action to READ when share as public, set to WRITE only when share as private * fix (test): Fix test scripts and add more test cases * documetation on org-name parameter * documentation: Update README.md * documentation: Add detail * fix: rearrange the sequence of exception thrown when new pygraphistry login with org on old server * doc: Add documentation for sharing tutorial notebook on sharing within organization, minor correction for other documentation * docs(readme.md): orgs * docs(readme.md): org mode privacy * feat (sso login): Add initial code to call API to get SSO login page and use state to retrieve token * fix (sso): Login with SSO to obtain jwt token * fix (cleanup): Clean up code for SSO login * feat : Use timeout=None to replace flag for blocking mode * refactor (sso): Handle cases for ipython console or jupyter notebook * doc (register function for SSO): Added docstring doc for register function for SSO login * doc (other SSO related function): Add docstring documentation * fix (sso login): to allow blocking mode for running in notebook * fix (register function): do not throw exception if register function does not pass in username/pwd or org_name * fix (typecheck): type check failure fix * fix (org_name): json object conversion problem for org_name if using property instead of calling function directly. Add some debug code * refactor (print -> logger.debug): change print to logger.debug * fix (org_name should only pass if org_name has value) * fix (type imcompatible) test failure * fix (org_name to cope with old and new server) * fix (mypy): fix mypy issue * fix (login): allow login with token * wip (login with personal key): added personal key id and personal key for the login/register * feat (personal key): Add personal key login capability initial code * wip (sso login): SSO login fix for site wide * Update README.md * wip (site wide sso login): when no org slug and idp_name passed in register, try with site wide SSO * fix (sso): Display SSO timeout exception in better and understandable * fix (register): adjust register logic to check the missing username, password, missing personal key id & personal key. Add pytest for testing the scenario * feat (sso enhancement): Add is_sso_login parameter to handle whether to do sso login when register * feat (test scripts): add test script for the register function * fix (arrow_uploader): bug in register with sso_login * feat (organization): Add switch_org function to allow switching of organization * fix (switch org): Fix switch org API * fix (typecheck and lint) * feat (add test script): add test script for switch org * refactor (personal key to personal key secret): refactor variable and fix test scripts * fix: docstring/typecheck Optional fix * fix: typecheck docstring * fix (raise exception instead of print): use a messages.py to keep the message as constant * docs(readme): login * fix : raise Exception instead of printing * fix (debug info): remove debugging info * fix(docs): personal key * wip (switch org): fix after switch org with org_name('xxx'), plotting does not take the updated org_name. Fix done, pending to remove debugging code * fix (clean up): Clean up debugging code for switch org * fix (mypy newer version issues): Optional[xxx] = None, instead of xxx = None * fix (mypy): fix the default value of layer to None * fix (test scripts): add ipython in dev_extra(stubs), fix unauthenticated issue in test_ipython * fix (personal key): fix personal key login does not switch org_name * fix(logging): print to logger * fix(types) Co-authored-by: lmeyerov --- CHANGELOG.md | 9 + README.md | 16 +- graphistry/ArrowFileUploader.py | 4 +- graphistry/PlotterBase.py | 13 +- graphistry/__init__.py | 1 + graphistry/arrow_uploader.py | 170 +++++++++- graphistry/bolt_util.py | 2 +- graphistry/dgl_utils.py | 8 +- graphistry/exceptions.py | 12 + graphistry/feature_utils.py | 4 +- graphistry/gremlin.py | 10 +- graphistry/layout/utils/layoutVertex.py | 3 +- graphistry/messages.py | 14 + graphistry/pygraphistry.py | 422 ++++++++++++++++++++++-- graphistry/tests/test_arrow_uploader.py | 59 ++++ graphistry/tests/test_ipython.py | 5 + graphistry/tests/test_pygraphistry.py | 107 +++++- setup.py | 2 +- 18 files changed, 791 insertions(+), 70 deletions(-) create mode 100644 graphistry/exceptions.py create mode 100644 graphistry/messages.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 55c324df1..a6298dd35 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,15 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ## [Development] +### Added + +* Personal keys: `register(personal_key_id=..., personal_key_secret=...)` +* SSO: `register()` (no user/pass), `register(idp_name=...)` (org-specific IDP) + +### Fixed + +* Type errors + ## [0.28.4 - 2022-10-22] ### Added diff --git a/README.md b/README.md index 5481f6abf..64444639c 100644 --- a/README.md +++ b/README.md @@ -51,16 +51,18 @@ You can use PyGraphistry with traditional Python data sources like CSVs, SQL, Ne ```python # pip install --user graphistry # minimal # pip install --user graphistry[bolt,gremlin,nodexl,igraph,networkx] # data plugins - # Requires Python 3.8+ (for scikit-learn 1.0+): - # pip install --user graphistry[umap-learn] # UMAP autoML (without text support) - # pip install --user graphistry[ai] # Full UMAP + GNN autoML, including sentence transformers (1GB+) + # AI modules: Python 3.8+ with scikit-learn 1.0+: + # pip install --user graphistry[umap-learn] # Lightweight: UMAP autoML (without text support); scikit-learn 1.0+ + # pip install --user graphistry[ai] # Heavy: Full UMAP + GNN autoML, including sentence transformers (1GB+) import graphistry graphistry.register(api=3, username='abc', password='xyz') # Free: hub.graphistry.com - - #graphistry.register(..., org_name='my-org') # Upload into an organization account - #graphistry.register(..., protocol='http', server='my.site.ngo') # Use with a self-hosted server - + #graphistry.register(..., personal_key_id='pkey_id', personal_key_secret='pkey_secret') # Key instead of username+password+org_name + #graphistry.register(..., is_sso_login=True) # SSO instead of password + #graphistry.register(..., org_name='my-org') # Upload into an organization account vs personal + #graphistry.register(..., protocol='https', server='my.site.ngo') # Use with a self-hosted server + # ... and if client (browser) URLs are different than python server<> graphistry server uploads + #graphistry.register(..., client_protocol_hostname='https://public.acme.co') ``` * **Notebook-friendly:** PyGraphistry plays well with interactive notebooks like [Jupyter](http://ipython.org), [Zeppelin](https://zeppelin.incubator.apache.org/), and [Databricks](http://databricks.com). Process, visualize, and drill into with graphs directly within your notebooks: diff --git a/graphistry/ArrowFileUploader.py b/graphistry/ArrowFileUploader.py index f5a06aedd..9c0fd2364 100644 --- a/graphistry/ArrowFileUploader.py +++ b/graphistry/ArrowFileUploader.py @@ -1,6 +1,6 @@ import pyarrow as pa, requests, sys from functools import lru_cache -from typing import Any, Tuple +from typing import Any, Tuple, Optional from weakref import WeakKeyDictionary from .util import setup_logger logger = setup_logger(__name__) @@ -120,7 +120,7 @@ def post_arrow(self, arr: pa.Table, file_id: str, url_opts: str = 'erase=true') ### def create_and_post_file( - self, arr: pa.Table, file_id: str = None, file_opts: dict = {}, upload_url_opts: str = 'erase=true', memoize: bool = True + self, arr: pa.Table, file_id: Optional[str] = None, file_opts: dict = {}, upload_url_opts: str = 'erase=true', memoize: bool = True ) -> Tuple[str, dict]: """ Create file and upload data for it. diff --git a/graphistry/PlotterBase.py b/graphistry/PlotterBase.py index 2d230bd17..444918164 100644 --- a/graphistry/PlotterBase.py +++ b/graphistry/PlotterBase.py @@ -1368,6 +1368,8 @@ def plot( .plot(es) """ + from .pygraphistry import PyGraphistry + logger.debug("1. @PloatterBase plot: PyGraphistry.org_name(): {}".format(PyGraphistry.org_name())) if graph is None: if self._edges is None: @@ -1381,15 +1383,19 @@ def plot( self._check_mandatory_bindings(not isinstance(n, type(None))) - from .pygraphistry import PyGraphistry + # from .pygraphistry import PyGraphistry api_version = PyGraphistry.api_version() + logger.debug("2. @PloatterBase plot: PyGraphistry.org_name(): {}".format(PyGraphistry.org_name())) if api_version == 1: dataset = self._plot_dispatch(g, n, name, description, 'json', self._style, memoize) if skip_upload: return dataset info = PyGraphistry._etl1(dataset) elif api_version == 3: + logger.debug("3. @PloatterBase plot: PyGraphistry.org_name(): {}".format(PyGraphistry.org_name())) PyGraphistry.refresh() + logger.debug("4. @PloatterBase plot: PyGraphistry.org_name(): {}".format(PyGraphistry.org_name())) + dataset = self._plot_dispatch(g, n, name, description, 'arrow', self._style, memoize) if skip_upload: return dataset @@ -1903,7 +1909,6 @@ def _make_dataset(self, edges, nodes, name, description, mode, metadata=None, me warn('Graph has no edges, may have rendering issues') except: 1 - #compatibility checks if mode == 'json': if not (metadata is None): @@ -1958,7 +1963,6 @@ def flatten_categorical(df): def _make_arrow_dataset(self, edges: pa.Table, nodes: pa.Table, name: str, description: str, metadata) -> ArrowUploader: from .pygraphistry import PyGraphistry - au : ArrowUploader = ArrowUploader( server_base_path=PyGraphistry.protocol() + '://' + PyGraphistry.server(), edges=edges, nodes=nodes, @@ -1971,8 +1975,7 @@ def _make_arrow_dataset(self, edges: pa.Table, nodes: pa.Table, name: str, descr 'agentversion': sys.modules['graphistry'].__version__, # type: ignore **(metadata or {}) }, - certificate_validation=PyGraphistry.certificate_validation(), - org_name=PyGraphistry.org_name()) + certificate_validation=PyGraphistry.certificate_validation()) au.edge_encodings = au.g_to_edge_encodings(self) au.node_encodings = au.g_to_node_encodings(self) diff --git a/graphistry/__init__.py b/graphistry/__init__.py index ca9446818..dd13cb60c 100644 --- a/graphistry/__init__.py +++ b/graphistry/__init__.py @@ -4,6 +4,7 @@ protocol, server, register, + sso_get_token, privacy, login, refresh, diff --git a/graphistry/arrow_uploader.py b/graphistry/arrow_uploader.py index 141721a11..1309459e2 100644 --- a/graphistry/arrow_uploader.py +++ b/graphistry/arrow_uploader.py @@ -1,6 +1,7 @@ from typing import List, Optional import io, pyarrow as pa, requests, sys + from .ArrowFileUploader import ArrowFileUploader from .util import setup_logger logger = setup_logger(__name__) @@ -18,12 +19,12 @@ def token(self, token: str): self.__token = token @property - def org_name(self) -> str: + def org_name(self) -> Optional[str]: return self.__org_name @org_name.setter - def org_name(self, org_name: str): - self.__org_name = org_name + def org_name(self, org_name: str) -> None: + self.__org_name: Optional[str] = org_name @property def dataset_id(self) -> str: @@ -136,6 +137,19 @@ def certificate_validation(self, certificate_validation): ########################################################################3 + # @property + # def sso_state(self) -> str: + # return getattr(self, '__sso_state', "") + + ########################################################################3 + + # @property + # def sso_auth_url(self) -> str: + # return getattr(self, '__sso_auth_url') + + ########################################################################3 + + def __init__(self, server_base_path='http://nginx', view_base_path='http://localhost', name = None, @@ -146,6 +160,7 @@ def __init__(self, metadata = None, certificate_validation = True, org_name: Optional[str] = None): + self.__name = name self.__description = description self.__server_base_path = server_base_path @@ -158,30 +173,65 @@ def __init__(self, self.__edge_encodings = edge_encodings self.__metadata = metadata self.__certificate_validation = certificate_validation - if org_name is not None: + self.__org_name = org_name if org_name else None + + if org_name: self.__org_name = org_name - + else: + # check current org_name + from .pygraphistry import PyGraphistry + if 'org_name' in PyGraphistry._config: + logger.debug("@ArrowUploader.__init__: There is an org_name : {}".format(PyGraphistry._config['org_name'])) + self.__org_name = PyGraphistry._config['org_name'] + else: + self.__org_name = None + + logger.debug("2. @ArrowUploader.__init__: After set self.org_name: {}, self.__org_name : {}".format(self.org_name, self.__org_name)) + + def login(self, username, password, org_name=None): - from .pygraphistry import PyGraphistry + # base_path = self.server_base_path + + json_data = {'username': username, 'password': password} + if org_name: + json_data.update({"org_name": org_name}) - base_path = self.server_base_path out = requests.post( - f'{base_path}/api-token-auth/', + f'{self.server_base_path}/api-token-auth/', + verify=self.certificate_validation, + json=json_data) + + return self._handle_login_response(out, org_name) + + def pkey_login(self, personal_key_id, personal_key_secret, org_name=None): + # json_data = {'personal_key_id': personal_key_id, 'personal_key_secret': personal_key} + json_data = {} + if org_name: + json_data.update({"org_name": org_name}) + + headers = {"Authorization": f'PersonalKey {personal_key_id}:{personal_key_secret}'} + + url = f'{self.server_base_path}/api/v2/auth/pkey/jwt/' + + out = requests.get( + url, verify=self.certificate_validation, - json={'username': username, 'password': password, "org_name": org_name}) + json=json_data, headers=headers) + return self._handle_login_response(out, org_name) + + def _handle_login_response(self, out, org_name): + from .pygraphistry import PyGraphistry json_response = None try: json_response = out.json() - if not ('token' in json_response): raise Exception(out.text) org = json_response.get('active_organization',{}) logged_in_org_name = org.get('slug', None) - if org_name: # caller pass in org_name if not logged_in_org_name: # no active_organization in JWT payload - raise Exception("Server does not support organization, please omit org_name") + raise Exception("You are not authorized to the organization '{}', or server does not support organization, please omit org_name parameter".format(org_name)) else: # if JWT response with org_name different than the pass in org_name # => org_name not found and return default organization (currently is personal org) @@ -195,9 +245,16 @@ def login(self, username, password, org_name=None): raise Exception("Organization {} is not found".format(org_name)) if not is_member: - raise Exception("You are not a member of {}".format(org_name)) + raise Exception("You are not authorized or not a member of {}".format(org_name)) - PyGraphistry.org_name(logged_in_org_name) + if logged_in_org_name is None and org_name is None: + if 'org_name' in PyGraphistry._config: + del PyGraphistry._config['org_name'] + else: + if org_name in PyGraphistry._config: + logger.debug("@ArrowUploder, handle login reponse, org_name: {}".format(PyGraphistry._config['org_name'])) + PyGraphistry._config['org_name'] = logged_in_org_name + # PyGraphistry.org_name(logged_in_org_name) except Exception: logger.error('Error: %s', out, exc_info=True) raise @@ -206,6 +263,87 @@ def login(self, username, password, org_name=None): return self + def sso_login(self, org_name=None, idp_name=None): + """ + Koa, 04 May 2022 Get SSO login auth_url or token + """ + # from .pygraphistry import PyGraphistry + base_path = self.server_base_path + + if org_name is None and idp_name is None: + print("Login to site wide SSO") + url = f'{base_path}/api/v2/g/sso/oidc/login/' + elif org_name is not None and idp_name is None: + print("Login to {} organization level SSO".format(org_name)) + url = f'{base_path}/api/v2/o/{org_name}/sso/oidc/login/' + elif org_name is not None and idp_name is not None: + print("Login to {} idp {} SSO".format(org_name, idp_name)) + url = f'{base_path}/api/v2/o/{org_name}/sso/oidc/login/{idp_name}/' + + # print("url : {}".format(url)) + out = requests.post( + url, data={'client-type': 'pygraphistry'}, + verify=self.certificate_validation + ) + # print(out.text) + json_response = None + try: + json_response = out.json() + logger.debug("@ArrowUploader.sso_login, json_response: {}".format(json_response)) + self.token = None + if not ('status' in json_response): + raise Exception(out.text) + else: + if json_response['status'] == 'OK': + logger.debug("@ArrowUploader.sso_login, json_data : {}".format(json_response['data'])) + if 'state' in json_response['data']: + self.sso_state = json_response['data']['state'] + self.sso_auth_url = json_response['data']['auth_url'] + else: + self.token = json_response['data']['token'] + elif json_response['status'] == 'ERR': + raise Exception(json_response['message']) + + except Exception: + logger.error('Error: %s', out, exc_info=True) + raise + + return self + + def sso_get_token(self, state): + """ + Koa, 04 May 2022 Use state to get token + """ + + # from .pygraphistry import PyGraphistry + + base_path = self.server_base_path + out = requests.get( + f'{base_path}/api/v2/o/sso/oidc/jwt/{state}/', + verify=self.certificate_validation + ) + json_response = None + try: + json_response = out.json() + # print("get_jwt : {}".format(json_response)) + self.token = None + if not ('status' in json_response): + raise Exception(out.text) + else: + if json_response['status'] == 'OK': + if 'token' in json_response['data']: + self.token = json_response['data']['token'] + if 'active_organization' in json_response['data']: + logger.debug("@ArrowUploader.sso_get_token, org_name: {}".format(json_response['data']['active_organization']['slug'])) + self.org_name = json_response['data']['active_organization']['slug'] + + except Exception: + logger.error('Error: %s', out, exc_info=True) + # raise + + return self + + def refresh(self, token=None): if token is None: token = self.token @@ -240,10 +378,9 @@ def verify(self, token=None) -> bool: def create_dataset(self, json): # noqa: F811 tok = self.token - if self.org_name: json['org_name'] = self.org_name - + logger.debug("@ArrowUploder create_dataset json: {}".format(json)) res = requests.post( self.server_base_path + '/api/v2/upload/datasets/', verify=self.certificate_validation, @@ -351,6 +488,7 @@ def post(self, as_files: bool = True, memoize: bool = True): """ Note: likely want to pair with self.maybe_post_share_link(g) """ + logger.debug("@ArrowUploader.post, self.org_name : {}".format(self.org_name)) if as_files: file_uploader = ArrowFileUploader(self) diff --git a/graphistry/bolt_util.py b/graphistry/bolt_util.py index 78b8db724..3f1f93689 100644 --- a/graphistry/bolt_util.py +++ b/graphistry/bolt_util.py @@ -82,7 +82,7 @@ def flatten_spatial_col(df : pd.DataFrame, col : str) -> pd.DataFrame: # noqa: for prop in ['x', 'y', 'z', 'srid', 'longtitude', 'latitude', 'height']: try: # v4.x + v5.x - s = df[col].apply(lambda v: getattr(v, prop, None)) + s = df[col].apply(lambda v: getattr(v, prop, None)) # type: ignore if len(s.dropna()) > 0: out_df[f'{col}_{prop}'] = s except: diff --git a/graphistry/dgl_utils.py b/graphistry/dgl_utils.py index 960c6ca36..6e66f4971 100644 --- a/graphistry/dgl_utils.py +++ b/graphistry/dgl_utils.py @@ -165,7 +165,7 @@ def pandas_to_sparse_adjacency(df, src, dst, weight_col): # ############################################################################## def pandas_to_dgl_graph( - df: pd.DataFrame, src: str, dst: str, weight_col: str = None, device: str = "cpu" + df: pd.DataFrame, src: str, dst: str, weight_col: Optional[str] = None, device: str = "cpu" ): """Turns an edge DataFrame with named src and dst nodes, to DGL graph :eg @@ -431,13 +431,13 @@ def build_gnn( X_edges: XSymbolic = None, y_nodes: YSymbolic = None, y_edges: YSymbolic = None, - weight_column: str = None, + weight_column: Optional[str] = None, reuse_if_existing=True, featurize_edges =True, use_node_scaler: str = "zscale", - use_node_scaler_target: str = None, + use_node_scaler_target: Optional[str] = None, use_edge_scaler: str = "zscale", - use_edge_scaler_target: str = None, + use_edge_scaler_target: Optional[str] = None, train_split: float = 0.8, device: str = "cpu", inplace: bool = False, diff --git a/graphistry/exceptions.py b/graphistry/exceptions.py new file mode 100644 index 000000000..1b7c12d15 --- /dev/null +++ b/graphistry/exceptions.py @@ -0,0 +1,12 @@ +class SsoException(Exception): + """ + Koa, 15 Sep 2022 Custom Base Exception to handle Sso exception scenario + """ + pass + + +class SsoRetrieveTokenTimeoutException(SsoException): + """ + Koa, 15 Sep 2022 Custom Exception to Sso retrieve token time out exception scenario + """ + pass diff --git a/graphistry/feature_utils.py b/graphistry/feature_utils.py index a3d8d1f7f..eeae74f4c 100644 --- a/graphistry/feature_utils.py +++ b/graphistry/feature_utils.py @@ -365,7 +365,7 @@ def check_if_currency(x: str): logger.warning(e) return False - mask = df[col].apply(lambda x: check_if_currency) + mask = df[col].apply(check_if_currency) return mask @@ -1758,7 +1758,7 @@ def scale(self, df, ydf=None, set_scaler=False, *args, **kwargs): def prune_weighted_edges_df_and_relabel_nodes( - wdf: pd.DataFrame, scale: float = 0.1, index_to_nodes_dict: Dict = None + wdf: pd.DataFrame, scale: float = 0.1, index_to_nodes_dict: Optional[Dict] = None ) -> pd.DataFrame: """ Prune the weighted edge DataFrame so to return high diff --git a/graphistry/gremlin.py b/graphistry/gremlin.py index 1a09bf93a..bb28dc801 100644 --- a/graphistry/gremlin.py +++ b/graphistry/gremlin.py @@ -825,11 +825,11 @@ def __init__(self, *args, **kwargs): def cosmos( self, - COSMOS_ACCOUNT: str = None, - COSMOS_DB: str = None, - COSMOS_CONTAINER: str = None, - COSMOS_PRIMARY_KEY: str = None, - gremlin_client: Client = None + COSMOS_ACCOUNT: Optional[str] = None, + COSMOS_DB: Optional[str] = None, + COSMOS_CONTAINER: Optional[str] = None, + COSMOS_PRIMARY_KEY: Optional[str] = None, + gremlin_client: Optional[Client] = None ): """ Provide credentials as arguments, as environment variables, or by providing a gremlinpython client diff --git a/graphistry/layout/utils/layoutVertex.py b/graphistry/layout/utils/layoutVertex.py index ca677159c..4d66ba6d3 100644 --- a/graphistry/layout/utils/layoutVertex.py +++ b/graphistry/layout/utils/layoutVertex.py @@ -1,4 +1,5 @@ # DEPRECRATED: Non-vector operators over non-vectorized data +from typing import Optional class LayoutVertex(object): """ @@ -13,7 +14,7 @@ class LayoutVertex(object): bar (float): the current barycenter of the vertex """ - def __init__(self, layer: int = None, is_dummy = 0): + def __init__(self, layer: Optional[int] = None, is_dummy = 0): self.layer = layer # layer number self.dummy = is_dummy self.root = None diff --git a/graphistry/messages.py b/graphistry/messages.py new file mode 100644 index 000000000..2bc203d35 --- /dev/null +++ b/graphistry/messages.py @@ -0,0 +1,14 @@ +# message (exception, error etc) constant + +MSG_REGISTER_MISSING_PASSWORD = 'Error: username exists but missing password' +MSG_REGISTER_MISSING_USERNAME = 'Error: password exist but missing username' + +MSG_REGISTER_MISSING_PKEY_SECRET = 'Error: personal key id exists but missing personal key secret' +MSG_REGISTER_MISSING_PKEY_ID = 'Error: personal key secret exists but missing personal key id' + + +MSG_REGISTER_ENTER_SSO_LOGIN = 'No username/password, personal key id/secret & token provided, enter SSO login' + +MSG_SWITCH_ORG_SUCCESS = "Switched to organization: {}" +MSG_SWITCH_ORG_NOT_FOUND = "No such organization id '{}'" +MSG_SWITCH_ORG_NOT_PERMITTED = "Not authorized to organization '{}'" diff --git a/graphistry/pygraphistry.py b/graphistry/pygraphistry.py index b012291c7..0a8dd7631 100644 --- a/graphistry/pygraphistry.py +++ b/graphistry/pygraphistry.py @@ -13,12 +13,24 @@ from . import util from . import bolt_util from .plotter import Plotter -from .util import setup_logger +from .util import in_databricks, setup_logger, in_ipython, make_iframe +from .exceptions import SsoRetrieveTokenTimeoutException + +from .messages import ( + MSG_REGISTER_MISSING_PASSWORD, + MSG_REGISTER_MISSING_USERNAME, + MSG_REGISTER_MISSING_PKEY_SECRET, + MSG_REGISTER_MISSING_PKEY_ID, + MSG_REGISTER_ENTER_SSO_LOGIN +) + + logger = setup_logger(__name__) ############################################################################### +SSO_GET_TOKEN_ELAPSE_SECONDS = 50 EnvVarNames = { "api_key": "GRAPHISTRY_API_KEY", @@ -53,6 +65,7 @@ "store_token_creds_in_memory": True, # Do not call API when all None "privacy": None, + "login_type": None } @@ -113,6 +126,15 @@ def authenticate(): PyGraphistry._check_key_and_version() PyGraphistry._is_authenticated = True + @staticmethod + def __reset_token_creds_in_memory(): + """Reset the token and creds in memory, used when switching hosts, switching register method""" + + PyGraphistry._config["api_key"] = None + PyGraphistry._is_authenticated = False + + + @staticmethod def not_implemented_thunk(): raise Exception("Must call login() first") @@ -123,10 +145,43 @@ def not_implemented_thunk(): def login(username, password, org_name=None, fail_silent=False): """Authenticate and set token for reuse (api=3). If token_refresh_ms (default: 10min), auto-refreshes token. By default, must be reinvoked within 24hr.""" + logger.debug("@PyGraphistry login : org_name :{} vs PyGraphistry.org_name() : {}".format(org_name, PyGraphistry.org_name())) + + if not org_name: + org_name = PyGraphistry.org_name() - if PyGraphistry._config["store_token_creds_in_memory"]: + if PyGraphistry._config['store_token_creds_in_memory']: PyGraphistry.relogin = lambda: PyGraphistry.login( - username, password, fail_silent + username, password, None, fail_silent + ) + + PyGraphistry._is_authenticated = False + token = ( + ArrowUploader( + server_base_path=PyGraphistry.protocol() + + "://" # noqa: W503 + + PyGraphistry.server(), # noqa: W503 + certificate_validation=PyGraphistry.certificate_validation(), + ) + .login(username, password, org_name) + .token + ) + + logger.debug("@PyGraphistry login After ArrowUploader.login: org_name :{} vs PyGraphistry.org_name() : {}".format(org_name, PyGraphistry.org_name())) + + PyGraphistry.api_token(token) + PyGraphistry._is_authenticated = True + + return PyGraphistry.api_token() + + @staticmethod + def pkey_login(personal_key_id, personal_key_secret, org_name=None, fail_silent=False): + """Authenticate with personal key/secret and set token for reuse (api=3). If token_refresh_ms (default: 10min), auto-refreshes token. + By default, must be reinvoked within 24hr.""" + + if PyGraphistry._config['store_token_creds_in_memory']: + PyGraphistry.relogin = lambda: PyGraphistry.pkey_login( + personal_key_id, personal_key_secret, org_name if org_name else PyGraphistry.org_name(), fail_silent ) PyGraphistry._is_authenticated = False @@ -137,7 +192,7 @@ def login(username, password, org_name=None, fail_silent=False): + PyGraphistry.server(), # noqa: W503 certificate_validation=PyGraphistry.certificate_validation(), ) - .login(username, password) + .pkey_login(personal_key_id, personal_key_secret, org_name) .token ) PyGraphistry.api_token(token) @@ -145,13 +200,167 @@ def login(username, password, org_name=None, fail_silent=False): return PyGraphistry.api_token() + @staticmethod + def sso_login(org_name=None, idp_name=None, sso_timeout=SSO_GET_TOKEN_ELAPSE_SECONDS): + """Authenticate with SSO and set token for reuse (api=3). + + :param org_name: Set login organization's name(slug). Defaults to user's personal organization. + :type org_name: Optional[str] + :param idp_name: Set sso login idp name. Default as None (for site-wide SSO / for the only idp record). + :type idp_name: Optional[str] + :param sso_timeout: Set sso login getting token timeout in seconds (blocking mode), set to None if non-blocking mode. Default as SSO_GET_TOKEN_ELAPSE_SECONDS. + :type sso_timeout: Optional[int] + :returns: None. + :rtype: None + + SSO Login logic. + + """ + + if PyGraphistry._config['store_token_creds_in_memory']: + PyGraphistry.relogin = lambda: PyGraphistry.sso_login( + org_name, idp_name, sso_timeout + ) + + PyGraphistry._is_authenticated = False + arrow_uploader = ArrowUploader( + server_base_path=PyGraphistry.protocol() + + "://" # noqa: W503 + + PyGraphistry.server(), # noqa: W503 + certificate_validation=PyGraphistry.certificate_validation(), + ).sso_login(org_name, idp_name) + + try: + if arrow_uploader.token: + PyGraphistry.api_token(arrow_uploader.token) + PyGraphistry._is_authenticated = True + arrow_uploader.token = None + return PyGraphistry.api_token() + except Exception: # required to log on + # print("required to log on") + PyGraphistry.sso_state(arrow_uploader.sso_state) + + auth_url = arrow_uploader.sso_auth_url + # print("auth_url : {}".format(auth_url)) + if auth_url and not PyGraphistry.api_token(): + PyGraphistry._handle_auth_url(auth_url, sso_timeout) + + + @staticmethod + def _handle_auth_url(auth_url, sso_timeout): + """Internal function to handle what to do with the auth_url + based on the client mode python/ipython console or notebook. + + :param auth_url: SSO auth url retrieved via API + :type auth_url: str + :param sso_timeout: Set sso login getting token timeout in seconds (blocking mode), set to None if non-blocking mode. Default as SSO_GET_TOKEN_ELAPSE_SECONDS. + :type sso_timeout: Optional[int] + :returns: None. + :rtype: None + + SSO Login logic. + + """ + + if in_ipython() or in_databricks(): # If run in notebook, just display the HTML + # from IPython.core.display import HTML + from IPython.display import display, HTML + display(HTML(f'Login SSO')) + print("Please click the above link to open browser to login") + print("Please close browser tab after SSO login to back to notebook") + # return HTML(make_iframe(auth_url, 20, extra_html=extra_html, override_html_style=override_html_style)) + else: + print("Please minimize browser after SSO login to back to pygraphistry") + + import webbrowser + input("Press Enter to open browser ...") + # open browser to auth_url + webbrowser.open(auth_url) + + if sso_timeout is not None: + time.sleep(1) + elapsed_time = 1 + token = None + + while True: + token, org_name = PyGraphistry._sso_get_token() + try: + if not token: + if elapsed_time % 10 == 1: + print("Waiting for token : {} seconds ...".format(sso_timeout - elapsed_time + 1)) + + time.sleep(1) + elapsed_time = elapsed_time + 1 + if elapsed_time > sso_timeout: + raise SsoRetrieveTokenTimeoutException("[SSO] Get token timeout") + else: + break + except SsoRetrieveTokenTimeoutException as toe: + logger.debug(toe, exc_info=1) + break + except Exception: + token = None + if token: + # set org_name to sso org + PyGraphistry._config['org_name'] = org_name + + print("Successfully get a token") + return PyGraphistry.api_token() + else: + return None + else: + print("Please run graphistry.sso_get_token() to complete the authentication") + + + @staticmethod + def sso_get_token(): + """ Get authentication token in SSO non-blocking mode""" + token, org_name = PyGraphistry._sso_get_token() + # set org_name to sso org + PyGraphistry._config['org_name'] = org_name + return token + + @staticmethod + def _sso_get_token(): + token = None + # get token from API using state + state = PyGraphistry.sso_state() + # print("_sso_get_token : {}".format(state)) + arrow_uploader = ArrowUploader( + server_base_path=PyGraphistry.protocol() + + "://" # noqa: W503 + + PyGraphistry.server(), # noqa: W503 + certificate_validation=PyGraphistry.certificate_validation(), + ).sso_get_token(state) + + try: + try: + token = arrow_uploader.token + org_name = arrow_uploader.org_name + except Exception: + pass + logger.debug("jwt token :{}".format(token)) + # print("jwt token :{}".format(token)) + PyGraphistry.api_token(token or PyGraphistry._config['api_token']) + # print("api_token() : {}".format(PyGraphistry.api_token())) + PyGraphistry._is_authenticated = True + token = PyGraphistry.api_token() + # print("api_token() : {}".format(token)) + return token, org_name + except: + # raise + pass + return None, None + @staticmethod def refresh(token=None, fail_silent=False): """Use self or provided JWT token to get a fresher one. If self token, internalize upon refresh.""" using_self_token = token is None + logger.debug("1. @PyGraphistry refresh, org_name: {}".format(PyGraphistry._config['org_name'])) try: if PyGraphistry.store_token_creds_in_memory(): logger.debug("JWT refresh via creds") + logger.debug("2. @PyGraphistry refresh :relogin") return PyGraphistry.relogin() logger.debug("JWT refresh via token") @@ -339,6 +548,8 @@ def register( username: Optional[str] = None, password: Optional[str] = None, token: Optional[str] = None, + personal_key_id: Optional[str] = None, + personal_key_secret: Optional[str] = None, server: Optional[str] = None, protocol: Optional[str] = None, api: Optional[Literal[1, 3]] = None, @@ -347,7 +558,10 @@ def register( token_refresh_ms: int = 10 * 60 * 1000, store_token_creds_in_memory: Optional[bool] = None, client_protocol_hostname: Optional[str] = None, - org_name: Optional[str] = None + org_name: Optional[str] = None, + idp_name: Optional[str] = None, + is_sso_login: Optional[bool] = False, + sso_timeout: Optional[int] = SSO_GET_TOKEN_ELAPSE_SECONDS ): """API key registration and server selection @@ -363,6 +577,10 @@ def register( :type password: Optional[str] :param token: Valid Account JWT token (2.0). Provide token, or username/password, but not both. :type token: Optional[str] + :param personal_key_id: Personal Key id for service account. + :type personal_key_id: Optional[str] + :param personal_key_secret: Personal Key secret for service account. + :type personal_key_secret: Optional[str] :param server: URL of the visualization server. :type server: Optional[str] :param protocol: Protocol to use for server uploaders, defaults to "https". @@ -383,9 +601,25 @@ def register( :type client_protocol_hostname: Optional[str] :param org_name: Set login organization's name(slug). Defaults to user's personal organization. :type org_name: Optional[str] + :param idp_name: Set sso login idp name. Default as None (for site-wide SSO / for the only idp record). + :type idp_name: Optional[str] + :param sso_timeout: Set sso login getting token timeout in seconds (blocking mode), set to None if non-blocking mode. Default as SSO_GET_TOKEN_ELAPSE_SECONDS. + :type sso_timeout: Optional[int] :returns: None. :rtype: None + **Example: Standard (2.0 api by org_name via SSO configured for site or for organization with only 1 IdP)** + :: + + import graphistry + graphistry.register(api=3, protocol='http', server='200.1.1.1', org_name="org-name", idp_name="idp-name") + + **Example: Standard (2.0 api by org_name via SSO IdP configured for an organization)** + :: + + import graphistry + graphistry.register(api=3, protocol='http', server='200.1.1.1', org_name="org-name") + **Example: Standard (2.0 api by username/password with org_name)** :: @@ -404,6 +638,12 @@ def register( import graphistry graphistry.register(api=3, protocol='http', server='200.1.1.1', token='abc') + **Example: Standard (by personal_key_id/personal_key_secret)** + :: + + import graphistry + graphistry.register(api=3, protocol='http', server='200.1.1.1', personal_key_id='ZD5872XKNF', personal_key_secret='SA0JJ2DTVT6LLO2S') + **Example: Remote browser to Graphistry-provided notebook server (2.0)** :: @@ -425,21 +665,48 @@ def register( PyGraphistry.client_protocol_hostname(client_protocol_hostname) PyGraphistry.certificate_validation(certificate_validation) PyGraphistry.store_token_creds_in_memory(store_token_creds_in_memory) + PyGraphistry.set_bolt_driver(bolt) + # Reset token creds + PyGraphistry.__reset_token_creds_in_memory() + if not (username is None) and not (password is None): PyGraphistry.login(username, password, org_name) - PyGraphistry.api_token(token or PyGraphistry._config['api_token']) - PyGraphistry.authenticate() - - PyGraphistry.set_bolt_driver(bolt) + PyGraphistry.api_token(token or PyGraphistry._config['api_token']) + PyGraphistry.authenticate() + elif (username is None and not (password is None)): + raise Exception(MSG_REGISTER_MISSING_USERNAME) + elif not (username is None) and password is None: + raise Exception(MSG_REGISTER_MISSING_PASSWORD) + elif not (personal_key_id is None) and not (personal_key_secret is None): + PyGraphistry.pkey_login(personal_key_id, personal_key_secret, org_name=org_name) + PyGraphistry.api_token(token or PyGraphistry._config['api_token']) + PyGraphistry.authenticate() + elif personal_key_id is None and not (personal_key_secret is None): + raise Exception(MSG_REGISTER_MISSING_PKEY_ID) + elif not (personal_key_id is None) and personal_key_secret is None: + raise Exception(MSG_REGISTER_MISSING_PKEY_SECRET) + elif not (token is None): + PyGraphistry.api_token(token or PyGraphistry._config['api_token']) + elif not (org_name is None) or is_sso_login: + print(MSG_REGISTER_ENTER_SSO_LOGIN) + PyGraphistry.sso_login(org_name, idp_name, sso_timeout=sso_timeout) + @staticmethod + def __check_login_type_to_reset_token_creds( + origin_login_type: str, + new_login_type: str, + ): + if origin_login_type != new_login_type: + PyGraphistry.__reset_token_creds_in_memory() + @staticmethod def privacy( - mode: Optional[str] = None, - notify: Optional[bool] = None, - invited_users: Optional[List] = None, - mode_action: Optional[str] = None, - message: Optional[str] = None, - ): + mode: Optional[str] = None, + notify: Optional[bool] = None, + invited_users: Optional[List] = None, + mode_action: Optional[str] = None, + message: Optional[str] = None + ): """Set global default sharing mode :param mode: Either "private" or "public" or "organization" @@ -848,10 +1115,10 @@ def neptune( @staticmethod def cosmos( - COSMOS_ACCOUNT: str = None, - COSMOS_DB: str = None, - COSMOS_CONTAINER: str = None, - COSMOS_PRIMARY_KEY: str = None, + COSMOS_ACCOUNT: Optional[str] = None, + COSMOS_DB: Optional[str] = None, + COSMOS_CONTAINER: Optional[str] = None, + COSMOS_PRIMARY_KEY: Optional[str] = None, gremlin_client: Any = None, ) -> Plotter: """Provide credentials as arguments, as environment variables, or by providing a gremlinpython client @@ -1822,6 +2089,13 @@ def _viz_url(info, url_params): extra, ) + @staticmethod + def _switch_org_url(org_name): + hostname = PyGraphistry._config["hostname"] + protocol = PyGraphistry._config["protocol"] + return "{}://{}/api/v2/o/{}/switch/".format(protocol, hostname, org_name) + + @staticmethod def _coerce_str(v): try: @@ -2000,6 +2274,53 @@ def layout_settings( scaling_ratio, ) + @staticmethod + def org_name(value=None): + """Set or get the org_name when register/login. + """ + + if value is None: + if 'org_name' in PyGraphistry._config: + return PyGraphistry._config['org_name'] + return None + + # setter, use switch_org instead + if 'org_name' not in PyGraphistry._config or value is not PyGraphistry._config['org_name']: + try: + PyGraphistry.switch_org(value.strip()) + # PyGraphistry._config['org_name'] = value.strip() + except: + raise Exception("Failed to switch organization") + + @staticmethod + def idp_name(value=None): + """Set or get the idp_name when register/login. + """ + + if value is None: + if 'idp_name' in PyGraphistry._config: + return PyGraphistry._config['idp_name'] + return None + + # setter + if 'idp_name' not in PyGraphistry._config or value is not PyGraphistry._config['idp_name']: + PyGraphistry._config['idp_name'] = value.strip() + + + @staticmethod + def sso_state(value=None): + """Set or get the sso_state when register/sso login. + """ + + if value is None: + if 'sso_state' in PyGraphistry._config: + return PyGraphistry._config['sso_state'] + return None + + # setter + if 'sso_state' not in PyGraphistry._config or value is not PyGraphistry._config['sso_state']: + PyGraphistry._config['sso_state'] = value.strip() + @staticmethod def scene_settings( menu: Optional[bool] = None, @@ -2021,19 +2342,65 @@ def scene_settings( ) scene_settings.__doc__ = Plotter().scene_settings.__doc__ + @staticmethod - def org_name(value=None): - """Set or get the org_name when register/login. + def personal_key_id(value: Optional[str] = None): + """Set or get the personal_key_id when register. """ if value is None: - if 'org_name' in PyGraphistry._config: - return PyGraphistry._config['org_name'] + if 'personal_key_id' in PyGraphistry._config: + return PyGraphistry._config['personal_key_id'] return None # setter - if 'org_name' not in PyGraphistry._config or value is not PyGraphistry._config['org_name']: + if 'personal_key_id' not in PyGraphistry._config or value is not PyGraphistry._config['personal_key_id']: + PyGraphistry._config['personal_key_id'] = value.strip() + + @staticmethod + def personal_key_secret(value: Optional[str] = None): + """Set or get the personal_key_secret when register. + """ + + if value is None: + if 'personal_key_secret' in PyGraphistry._config: + return PyGraphistry._config['personal_key_secret'] + return None + + # setter + if 'personal_key_secret' not in PyGraphistry._config or value is not PyGraphistry._config['personal_key']: + PyGraphistry._config['personal_key_secret'] = value.strip() + + @staticmethod + def switch_org(value): + # print(PyGraphistry._switch_org_url(value)) + response = requests.post( + PyGraphistry._switch_org_url(value), + data={'slug': value}, + headers={'Authorization': f'Bearer {PyGraphistry.api_token()}'}, + verify=PyGraphistry._config["certificate_validation"], + ) + result = PyGraphistry._handle_api_response(response) + + if result is True: PyGraphistry._config['org_name'] = value.strip() + logger.info("Switched to organization: {}".format(value.strip())) + else: # print the error message + raise Exception(result) + + @staticmethod + def _handle_api_response(response): + try: + json_response = response.json() + if json_response.get('status', None) == 'OK': + return True + else: + return json_response.get('message', '') + except: + logger.error('Error: %s', response, exc_info=True) + raise Exception("Unknown Error") + + client_protocol_hostname = PyGraphistry.client_protocol_hostname @@ -2041,6 +2408,7 @@ def org_name(value=None): server = PyGraphistry.server protocol = PyGraphistry.protocol register = PyGraphistry.register +sso_get_token = PyGraphistry.sso_get_token privacy = PyGraphistry.privacy login = PyGraphistry.login refresh = PyGraphistry.refresh @@ -2078,9 +2446,15 @@ def org_name(value=None): gsql = PyGraphistry.gsql layout_settings = PyGraphistry.layout_settings org_name = PyGraphistry.org_name +idp_name = PyGraphistry.idp_name +sso_state = PyGraphistry.sso_state scene_settings = PyGraphistry.scene_settings from_igraph = PyGraphistry.from_igraph from_cugraph = PyGraphistry.from_cugraph +personal_key_id = PyGraphistry.personal_key_id +personal_key_secret = PyGraphistry.personal_key_secret +switch_org = PyGraphistry.switch_org + class NumpyJSONEncoder(json.JSONEncoder): diff --git a/graphistry/tests/test_arrow_uploader.py b/graphistry/tests/test_arrow_uploader.py index 9e366c69b..e09c3e56a 100644 --- a/graphistry/tests/test_arrow_uploader.py +++ b/graphistry/tests/test_arrow_uploader.py @@ -289,3 +289,62 @@ def test_login_with_org_valid_org_name_not_member(self, mock_post): with pytest.raises(Exception): au.token + + @mock.patch('requests.post') + def test_sso_login_when_required_authentication(self, mock_post): + + mock_resp = self._mock_response( + json_data={ + 'auth_url': 'https://sso-idp-host/authorize?state=xxuixld', + 'state': 'xxuixld' + }) + mock_post.return_value = mock_resp + + au = ArrowUploader() + + with pytest.raises(Exception): + au.sso_login(org_name="mock-org", idp_name="mock-idp") + + with pytest.raises(Exception): + au.sso_state == 'xxuixld' + au.auth_url == 'https://sso-idp-host/authorize?state=xxuixld' + + @mock.patch('requests.post') + def test_sso_login_when_already_authenticated(self, mock_post): + + mock_resp = self._mock_response( + json_data={ + 'state': 'xxuixld' + }) + mock_post.return_value = mock_resp + + au = ArrowUploader() + + with pytest.raises(Exception): + au.sso_login(org_name="mock-org", idp_name="mock-idp") + + with pytest.raises(Exception): + assert au.sso_state == 'xxuixld' + + + @mock.patch('requests.post') + def test_sso_login_get_sso_token(self, mock_post): + + mock_resp = self._mock_response( + json_data={ + 'token': '123', + 'active_organization': { + "slug": "mock-org", + 'is_found': True, + 'is_member': True + } + }) + mock_post.return_value = mock_resp + + au = ArrowUploader() + + with pytest.raises(Exception): + au.sso_get_token(state='abcdwerd') + + with pytest.raises(Exception): + assert au.token == '123' diff --git a/graphistry/tests/test_ipython.py b/graphistry/tests/test_ipython.py index 856228457..6e8ac76fa 100644 --- a/graphistry/tests/test_ipython.py +++ b/graphistry/tests/test_ipython.py @@ -7,6 +7,7 @@ @patch("webbrowser.open") @patch("requests.post", return_value=Fake_Response()) class TestPlotterReturnValue(NoAuthTestCase): + @patch("graphistry.PlotterBase.in_ipython") def test_no_ipython(self, mock_in_ipython, mock_post, mock_open): mock_in_ipython.return_value = False @@ -19,5 +20,9 @@ def test_no_ipython(self, mock_in_ipython, mock_post, mock_open): @patch("graphistry.PlotterBase.in_ipython") def test_ipython(self, mock_in_ipython, mock_post, mock_open): mock_in_ipython.return_value = True + + # The setUpClass in NoAuthTestCase only run once, so, reset the _is_authenticated to True here + graphistry.pygraphistry.PyGraphistry._is_authenticated = True + widget = graphistry.bind(source="src", destination="dst").plot(triangleEdges) self.assertIsInstance(widget, IPython.core.display.HTML) diff --git a/graphistry/tests/test_pygraphistry.py b/graphistry/tests/test_pygraphistry.py index 6ae5348d7..126fb7d5d 100644 --- a/graphistry/tests/test_pygraphistry.py +++ b/graphistry/tests/test_pygraphistry.py @@ -1,8 +1,19 @@ # -*- coding: utf-8 -*- -import unittest +import unittest, pytest +from mock import patch + +from graphistry.pygraphistry import PyGraphistry +from graphistry.messages import ( + MSG_REGISTER_MISSING_PASSWORD, + MSG_REGISTER_MISSING_USERNAME, + MSG_REGISTER_MISSING_PKEY_SECRET, + MSG_REGISTER_MISSING_PKEY_ID, + MSG_SWITCH_ORG_SUCCESS, + MSG_SWITCH_ORG_NOT_FOUND, + MSG_SWITCH_ORG_NOT_PERMITTED +) -from graphistry import PyGraphistry # TODO mock requests for testing actual effectful code @@ -16,3 +27,95 @@ def test_overrides(self): assert PyGraphistry.store_token_creds_in_memory() is True PyGraphistry.register(store_token_creds_in_memory=False) assert PyGraphistry.store_token_creds_in_memory() is False + + +def test_register_with_only_username(capfd): + with pytest.raises(Exception) as exc_info: + PyGraphistry.register(username='only_username') + + assert str(exc_info.value) == MSG_REGISTER_MISSING_PASSWORD + + +def test_register_with_only_password(capfd): + with pytest.raises(Exception) as exc_info: + PyGraphistry.register(password='only_password') + + assert str(exc_info.value) == MSG_REGISTER_MISSING_USERNAME + + +def test_register_with_only_personal_key_id(capfd): + with pytest.raises(Exception) as exc_info: + PyGraphistry.register(personal_key_id='only_personal_key_id') + + assert str(exc_info.value) == MSG_REGISTER_MISSING_PKEY_SECRET + + +def test_register_with_only_personal_key_secret(capfd): + with pytest.raises(Exception) as exc_info: + PyGraphistry.register(personal_key_secret='only_personal_key_secret') + + assert str(exc_info.value) == MSG_REGISTER_MISSING_PKEY_ID + + +class FakeRequestResponse(object): + def __init__(self, response): + self.response = response + def raise_for_status(self): + pass + + def json(self): + return self.response + + +switch_org_success_response = { + "status": "OK", + "message": MSG_SWITCH_ORG_SUCCESS.format('success-org'), + "data": [] +} + + +org_not_exist_response = { + "status": "Failed", + "message": MSG_SWITCH_ORG_NOT_FOUND.format('not-exist-org'), + "data": [] +} + +org_not_permitted_response = { + "status": "Failed", + "message": MSG_SWITCH_ORG_NOT_PERMITTED.format('not-permitted-org'), + "data": [] +} + +# Print has been switch to logger.info +@patch("requests.post", return_value=FakeRequestResponse(switch_org_success_response)) +def test_switch_organization_success(mock_response, capfd): + PyGraphistry.org_name("success-org") + out, err = capfd.readouterr() + assert out == '' + + +@patch("requests.post", return_value=FakeRequestResponse(org_not_exist_response)) +def test_switch_organization_not_exist(mock_response, capfd): + org_name = "not-exist-org" + with pytest.raises(Exception) as exc_info: + PyGraphistry.org_name(org_name) + + assert str(exc_info.value) == "Failed to switch organization" + + # PyGraphistry.org_name("not-exist-org") + # out, err = capfd.readouterr() + # assert "Failed to switch organization" in out + + +@patch("requests.post", return_value=FakeRequestResponse(org_not_permitted_response)) +def test_switch_organization_not_permitted(mock_response, capfd): + org_name = "not-permitted-org" + with pytest.raises(Exception) as exc_info: + PyGraphistry.org_name(org_name) + + assert str(exc_info.value) == "Failed to switch organization" + + + # PyGraphistry.org_name("not-permitted-org") + # out, err = capfd.readouterr() + # assert "Failed to switch organization" in out diff --git a/setup.py b/setup.py index a7047fc2e..dcebb8002 100755 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ def unique_flatten_dict(d): ] stubs = [ - 'pandas-stubs', 'types-requests' + 'pandas-stubs', 'types-requests', 'ipython' ] dev_extras = {