Skip to content

Commit

Permalink
refactor(prune_weighted_edges_df_and_relabel_nodes): move to relevant…
Browse files Browse the repository at this point in the history
… file
  • Loading branch information
lmeyerov committed Sep 15, 2024
1 parent 54d8ffb commit 02ca1ad
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 53 deletions.
48 changes: 0 additions & 48 deletions graphistry/feature_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1821,54 +1821,6 @@ def scale(self, X=None, y=None, return_pipeline=False, *args, **kwargs):
return X, y


# ######################################################################################################################
#
#
#
# ######################################################################################################################


def prune_weighted_edges_df_and_relabel_nodes(
    wdf: pd.DataFrame, scale: float = 0.1, index_to_nodes_dict: Optional[Dict] = None
) -> pd.DataFrame:
    """Prune the weighted edge DataFrame, keeping only the highest-weight edges.

    An edge is kept when its weight is at least
    ``max(max_weight + eps - scale, min_weight - eps)`` with ``eps = 1e-3``.
    The mean and standard deviation of the weights are only logged; they do
    not enter the threshold.

    :param wdf: weighted edge DataFrame gotten via UMAP
    :param scale: width of the window below the (eps-padded) maximum weight
        that is kept; lower values mean fewer edges. A value below ``eps``
        puts the threshold above the maximum weight, so no edge survives.
    :param index_to_nodes_dict: dict of index to node name;
        remap src/dst values if provided
    :return: the pruned edges; when ``index_to_nodes_dict`` is given, a
        relabeled copy, so ``wdf`` itself is never modified
    """
    # Only the weight column's statistics are needed to prune edges.
    desc = wdf[config.WEIGHT].describe()
    eps = 1e-3

    mean = desc["mean"]
    std = desc["std"]
    max_val = desc["max"] + eps
    min_val = desc["min"] - eps
    # The floor at (min - eps) means that once `scale` covers the whole
    # weight range every edge is kept, rather than the threshold dropping
    # arbitrarily far below the data.
    thresh = np.max([max_val - scale, min_val])

    logger.info(
        f" -- edge weights: mean({mean:.2f}), "
        f"std({std:.2f}), max({max_val}), "
        f"min({min_val:.2f}), thresh({thresh:.2f})"
    )
    wdf2 = wdf[wdf[config.WEIGHT] >= thresh]
    logger.info(
        " -- Pruning weighted edge DataFrame "
        f"from {len(wdf):,} to {len(wdf2):,} edges."
    )
    if index_to_nodes_dict is not None:
        # Copy before assigning: writing columns onto a filtered selection is
        # chained assignment (SettingWithCopyWarning in pandas) and may or may
        # not propagate to the caller's frame.
        wdf2 = wdf2.copy()
        wdf2[config.SRC] = wdf2[config.SRC].map(index_to_nodes_dict)
        wdf2[config.DST] = wdf2[config.DST].map(index_to_nodes_dict)
    return wdf2


# ###########################################################################
#
# Fast Memoize
Expand Down
51 changes: 46 additions & 5 deletions graphistry/umap_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,52 @@ def umap_graph_to_weighted_edges(umap_graph, engine, is_legacy, cfg=config):
_weighted_edges_df = pd.DataFrame(
{src: coo.row, dst: coo.col, weight_col: coo.data}
)
elif (engine == "cuml") and not is_legacy:
_weighted_edges_df = pd.DataFrame(
{src: coo.get().row, dst: coo.get().col, weight_col: coo.get().data}
)
return _weighted_edges_df
)


##############################################################################


def prune_weighted_edges_df_and_relabel_nodes(
    wdf: DataFrameLike, scale: float = 0.1, index_to_nodes_dict: Optional[Dict] = None
) -> pd.DataFrame:
    """Prune the weighted edge DataFrame, keeping only the highest-weight edges.

    An edge is kept when its weight is at least
    ``max(max_weight + eps - scale, min_weight - eps)`` with ``eps = 1e-3``.
    The mean and standard deviation of the weights are only logged; they do
    not enter the threshold.

    :param wdf: weighted edge DataFrame gotten via UMAP
    :param scale: width of the window below the (eps-padded) maximum weight
        that is kept; lower values mean fewer edges. A value below ``eps``
        puts the threshold above the maximum weight, so no edge survives.
    :param index_to_nodes_dict: dict of index to node name;
        remap src/dst values if provided
    :return: the pruned edges; when ``index_to_nodes_dict`` is given, a
        relabeled copy, so ``wdf`` itself is never modified
    """
    # Only the weight column's statistics are needed to prune edges.
    desc = wdf[config.WEIGHT].describe()
    eps = 1e-3

    mean = desc["mean"]
    std = desc["std"]
    max_val = desc["max"] + eps
    min_val = desc["min"] - eps
    # The floor at (min - eps) means that once `scale` covers the whole
    # weight range every edge is kept, rather than the threshold dropping
    # arbitrarily far below the data.
    thresh = np.max([max_val - scale, min_val])

    logger.info(
        f" -- edge weights: mean({mean:.2f}), "
        f"std({std:.2f}), max({max_val}), "
        f"min({min_val:.2f}), thresh({thresh:.2f})"
    )
    wdf2 = wdf[wdf[config.WEIGHT] >= thresh]
    logger.info(
        " -- Pruning weighted edge DataFrame "
        f"from {len(wdf):,} to {len(wdf2):,} edges."
    )
    if index_to_nodes_dict is not None:
        # Copy before assigning: writing columns onto a filtered selection is
        # chained assignment (SettingWithCopyWarning in pandas) and may or may
        # not propagate to the caller's frame.
        wdf2 = wdf2.copy()
        wdf2[config.SRC] = wdf2[config.SRC].map(index_to_nodes_dict)
        wdf2[config.DST] = wdf2[config.DST].map(index_to_nodes_dict)
    return wdf2



class UMAPMixin(MIXIN_BASE):
Expand Down

0 comments on commit 02ca1ad

Please sign in to comment.