Merge pull request ScrapeGraphAI#135 from S4mpl3r/feature
VinciGit00 authored May 3, 2024
2 parents 98dec36 + 819cbcd commit 2abe05a
Showing 3 changed files with 88 additions and 28 deletions.
2 changes: 1 addition & 1 deletion examples/groq/smart_scraper_groq_openai.py
@@ -25,7 +25,7 @@
    },
    "embeddings": {
        "api_key": openai_key,
-        "model": "gpt-3.5-turbo",
+        "model": "openai",
    },
    "headless": False
}
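Note: with the new _create_embedder dispatch in abstract_graph.py (shown below), the "model" field of the embeddings block names a provider ("openai") rather than a chat model ("gpt-3.5-turbo"). A minimal sketch of the resulting config, assuming the groq_key and openai_key variables and the Groq chat model used elsewhere in this example file:

    graph_config = {
        "llm": {
            "model": "groq/gemma-7b-it",  # assumption: the Groq chat model used in this example
            "api_key": groq_key,
            "temperature": 0,
        },
        "embeddings": {
            "api_key": openai_key,
            "model": "openai",  # provider name; matched by substring in _create_embedder
        },
        "headless": False,
    }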
88 changes: 86 additions & 2 deletions scrapegraphai/graphs/abstract_graph.py
@@ -5,8 +5,12 @@
from abc import ABC, abstractmethod
from typing import Optional

-from ..models import OpenAI, Gemini, Ollama, AzureOpenAI, HuggingFace, Groq, Bedrock
+from langchain_aws.embeddings.bedrock import BedrockEmbeddings
+from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
+from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
+
from ..helpers import models_tokens
+from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI


class AbstractGraph(ABC):
@@ -43,7 +47,8 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
        self.source = source
        self.config = config
        self.llm_model = self._create_llm(config["llm"], chat=True)
-        self.embedder_model = self.llm_model if "embeddings" not in config else self._create_llm(
+        self.embedder_model = self._create_default_embedder(
+            ) if "embeddings" not in config else self._create_embedder(
            config["embeddings"])

        # Set common configuration parameters
@@ -165,6 +170,85 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
        else:
            raise ValueError(
                "Model provided by the configuration not supported")

+    def _create_default_embedder(self) -> object:
+        """
+        Create an embedding model instance based on the chosen llm model.
+        Returns:
+            object: An instance of the embedding model client.
+        Raises:
+            ValueError: If the model is not supported.
+        """
+
+        if isinstance(self.llm_model, OpenAI):
+            return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
+        elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
+            return self.llm_model
+        elif isinstance(self.llm_model, AzureOpenAI):
+            return AzureOpenAIEmbeddings()
+        elif isinstance(self.llm_model, Ollama):
+            # unwrap the kwargs from the model, which is a dict
+            params = self.llm_model._lc_kwargs
+            # remove streaming and temperature
+            params.pop("streaming", None)
+            params.pop("temperature", None)
+
+            return OllamaEmbeddings(**params)
+        elif isinstance(self.llm_model, HuggingFace):
+            return HuggingFaceHubEmbeddings(model=self.llm_model.model)
+        elif isinstance(self.llm_model, Bedrock):
+            return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
+        else:
+            raise ValueError("Embedding Model missing or not supported")

+    def _create_embedder(self, embedder_config: dict) -> object:
+        """
+        Create an embedding model instance based on the configuration provided.
+        Args:
+            embedder_config (dict): Configuration parameters for the embedding model.
+        Returns:
+            object: An instance of the embedding model client.
+        Raises:
+            KeyError: If the model is not supported.
+        """
+
+        # Instantiate the embedding model based on the model name
+        if "openai" in embedder_config["model"]:
+            return OpenAIEmbeddings(api_key=embedder_config["api_key"])
+
+        elif "azure" in embedder_config["model"]:
+            return AzureOpenAIEmbeddings()
+
+        elif "ollama" in embedder_config["model"]:
+            embedder_config["model"] = embedder_config["model"].split("/")[-1]
+            try:
+                models_tokens["ollama"][embedder_config["model"]]
+            except KeyError:
+                raise KeyError("Model not supported")
+            return OllamaEmbeddings(**embedder_config)
+
+        elif "hugging_face" in embedder_config["model"]:
+            try:
+                models_tokens["hugging_face"][embedder_config["model"]]
+            except KeyError:
+                raise KeyError("Model not supported")
+            return HuggingFaceHubEmbeddings(model=embedder_config["model"])
+
+        elif "bedrock" in embedder_config["model"]:
+            embedder_config["model"] = embedder_config["model"].split("/")[-1]
+            try:
+                models_tokens["bedrock"][embedder_config["model"]]
+            except KeyError:
+                raise KeyError("Model not supported")
+            return BedrockEmbeddings(client=None, model_id=embedder_config["model"])
+        else:
+            raise ValueError(
+                "Model provided by the configuration not supported")

    def get_state(self, key=None) -> dict:
        """
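In short, AbstractGraph now resolves the embedder once at construction time: when the config has no "embeddings" key, _create_default_embedder derives an embeddings client from the chat model's class; when it does, _create_embedder dispatches on the provider substring in embedder_config["model"]. A minimal sketch of both paths, assuming the public SmartScraperGraph subclass and a hypothetical openai_key variable:

    from scrapegraphai.graphs import SmartScraperGraph

    # Path 1: no "embeddings" block -- _create_default_embedder infers the
    # embedder from the chat model (an OpenAI llm yields OpenAIEmbeddings).
    config_default = {
        "llm": {"api_key": openai_key, "model": "gpt-3.5-turbo"},
    }

    # Path 2: explicit "embeddings" block -- _create_embedder matches the
    # provider substring ("openai", "azure", "ollama", "hugging_face", "bedrock").
    config_explicit = {
        "llm": {"api_key": openai_key, "model": "gpt-3.5-turbo"},
        "embeddings": {"api_key": openai_key, "model": "openai"},
    }

    graph = SmartScraperGraph(
        prompt="List all the article titles on the page",
        source="https://example.com",
        config=config_explicit,
    )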
26 changes: 1 addition & 25 deletions scrapegraphai/nodes/rag_node.py
@@ -87,31 +87,7 @@ def execute(self, state: dict) -> dict:
        if self.verbose:
            print("--- (updated chunks metadata) ---")

-        # check if embedder_model is provided, if not use llm_model
-        embedding_model = self.embedder_model if self.embedder_model else self.llm_model
-
-        if isinstance(embedding_model, OpenAI):
-            embeddings = OpenAIEmbeddings(
-                api_key=embedding_model.openai_api_key)
-        elif isinstance(embedding_model, AzureOpenAIEmbeddings):
-            embeddings = embedding_model
-        elif isinstance(embedding_model, AzureOpenAI):
-            embeddings = AzureOpenAIEmbeddings()
-        elif isinstance(embedding_model, Ollama):
-            # unwrap the kwargs from the model, which is a dict
-            params = embedding_model._lc_kwargs
-            # remove streaming and temperature
-            params.pop("streaming", None)
-            params.pop("temperature", None)
-
-            embeddings = OllamaEmbeddings(**params)
-        elif isinstance(embedding_model, HuggingFace):
-            embeddings = HuggingFaceHubEmbeddings(model=embedding_model.model)
-        elif isinstance(embedding_model, Bedrock):
-            embeddings = BedrockEmbeddings(
-                client=None, model_id=embedding_model.model_id)
-        else:
-            raise ValueError("Embedding Model missing or not supported")
+        embeddings = self.embedder_model

        retriever = FAISS.from_documents(
            chunked_docs, embeddings).as_retriever()
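Design note: the isinstance ladder that used to live here (including the Ollama kwargs unwrapping) now runs once in AbstractGraph, so every node that needs embeddings receives a ready-made client. The Ollama branch pops chat-only kwargs before reusing them, since the embeddings class should not inherit generation settings. A small sketch with hypothetical values:

    from langchain_community.embeddings import OllamaEmbeddings

    # _lc_kwargs holds the constructor kwargs LangChain stored on the chat model
    # (hypothetical values shown here).
    chat_kwargs = {
        "model": "llama3",
        "base_url": "http://localhost:11434",
        "temperature": 0,    # generation setting, dropped for the embedder
        "streaming": True,   # chat-only flag that OllamaEmbeddings does not take
    }
    chat_kwargs.pop("streaming", None)
    chat_kwargs.pop("temperature", None)
    embedder = OllamaEmbeddings(**chat_kwargs)  # model and base_url carry over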
