Skip to content

Commit

Permalink
Merge pull request #1062 from facebookexperimental/feature/visualizat…
Browse files Browse the repository at this point in the history
…ion-interface

visualizer interface
  • Loading branch information
shivkanthb authored Oct 8, 2024
2 parents e0458e7 + 2deb356 commit 70f4560
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 0 deletions.
17 changes: 17 additions & 0 deletions python/src/robyn/visualization/allocator_visualizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Dict, Any
import matplotlib.pyplot as plt
from .base_visualizer import BaseVisualizer

class AllocatorVisualizer(BaseVisualizer):
def __init__(self, allocator_data: Dict[str, Any]):
super().__init__()
self.allocator_data = allocator_data

def plot_allocator(self) -> plt.Figure:
"""
Plot allocator's output.
Returns:
plt.Figure: The generated figure.
"""
pass
9 changes: 9 additions & 0 deletions python/src/robyn/visualization/base_visualizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import matplotlib.pyplot as plt

class BaseVisualizer:
def __init__(self):
pass

def _setup_plot(self):
# Common plot setup logic
pass
41 changes: 41 additions & 0 deletions python/src/robyn/visualization/input_visualizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import Dict, Any
import matplotlib.pyplot as plt
from .base_visualizer import BaseVisualizer

class InputVisualizer(BaseVisualizer):
def __init__(self, input_data: Dict[str, Any]):
super().__init__()
self.input_data = input_data

def plot_adstock(self) -> plt.Figure:
"""
Create example plots for adstock hyperparameters.
Returns:
plt.Figure: The generated figure.
"""
fig, ax = plt.subplots()
# Add plotting logic here
return fig

def plot_saturation(self) -> plt.Figure:
"""
Create example plots for saturation hyperparameters.
Returns:
plt.Figure: The generated figure.
"""
fig, ax = plt.subplots()
# Add plotting logic here
return fig

def plot_spend_exposure_fit(self) -> Dict[str, plt.Figure]:
"""
Check spend exposure fit if available.
Returns:
Dict[str, plt.Figure]: A dictionary of generated figures.
"""
figures = {}
# Add plotting logic here
return figures
49 changes: 49 additions & 0 deletions python/src/robyn/visualization/model_visualizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import Dict, Any
import matplotlib.pyplot as plt
from .base_visualizer import BaseVisualizer

class ModelVisualizer(BaseVisualizer):
def __init__(self, model_data: Dict[str, Any]):
super().__init__()
self.model_data = model_data

def plot_moo_distribution(self) -> plt.Figure:
"""
Plot MOO (multi-objective optimization) distribution.
Returns:
plt.Figure: The generated figure.
"""
pass

def plot_moo_cloud(self) -> plt.Figure:
"""
Plot MOO (multi-objective optimization) cloud.
Returns:
plt.Figure: The generated figure.
"""
pass

def plot_ts_validation(self) -> plt.Figure:
"""
Plot time-series validation.
Returns:
plt.Figure: The generated figure.
"""
pass

def plot_onepager(self, input_collect: Dict[str, Any], output_collect: Dict[str, Any], select_model: str) -> Dict[str, plt.Figure]:
"""
Generate one-pager plots for a selected model.
Args:
input_collect (Dict[str, Any]): The input collection data.
output_collect (Dict[str, Any]): The output collection data.
select_model (str): The selected model identifier.
Returns:
Dict[str, plt.Figure]: A dictionary of generated figures.
"""
pass
26 changes: 26 additions & 0 deletions python/src/robyn/visualization/response_visualizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Dict, Any
import matplotlib.pyplot as plt
from .base_visualizer import BaseVisualizer

class ResponseVisualizer(BaseVisualizer):
def __init__(self, response_data: Dict[str, Any]):
super().__init__()
self.response_data = response_data

def plot_response(self) -> plt.Figure:
"""
Plot response curves.
Returns:
plt.Figure: The generated figure.
"""
pass

def plot_marginal_response(self) -> plt.Figure:
"""
Plot marginal response curves.
Returns:
plt.Figure: The generated figure.
"""
pass
52 changes: 52 additions & 0 deletions python/src/robyn/visualization/robyn_visualizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import Dict, Any
import matplotlib.pyplot as plt
from .input_visualizer import InputVisualizer
from .model_visualizer import ModelVisualizer
from .allocator_visualizer import AllocatorVisualizer
from .response_visualizer import ResponseVisualizer

class RobynVisualizer:
def __init__(self):
self.input_visualizer = None
self.model_visualizer = None
self.allocator_visualizer = None
self.response_visualizer = None

def set_input_data(self, input_data: Dict[str, Any]):
self.input_visualizer = InputVisualizer(input_data)

def set_model_data(self, model_data: Dict[str, Any]):
self.model_visualizer = ModelVisualizer(model_data)

def set_allocator_data(self, allocator_data: Dict[str, Any]):
self.allocator_visualizer = AllocatorVisualizer(allocator_data)

def set_response_data(self, response_data: Dict[str, Any]):
self.response_visualizer = ResponseVisualizer(response_data)

def plot_adstock(self) -> plt.Figure:
return self.input_visualizer.plot_adstock()

def plot_saturation(self) -> plt.Figure:
return self.input_visualizer.plot_saturation()

def plot_moo_distribution(self) -> plt.Figure:
return self.model_visualizer.plot_moo_distribution()

def plot_moo_cloud(self) -> plt.Figure:
return self.model_visualizer.plot_moo_cloud()

def plot_ts_validation(self) -> plt.Figure:
return self.model_visualizer.plot_ts_validation()

def plot_onepager(self, input_collect: Dict[str, Any], output_collect: Dict[str, Any], select_model: str) -> Dict[str, plt.Figure]:
return self.model_visualizer.plot_onepager(input_collect, output_collect, select_model)

def plot_allocator(self) -> plt.Figure:
return self.allocator_visualizer.plot_allocator()

def plot_response(self) -> plt.Figure:
return self.response_visualizer.plot_response()

def plot_marginal_response(self) -> plt.Figure:
return self.response_visualizer.plot_marginal_response()

0 comments on commit 70f4560

Please sign in to comment.