
Commit

Merge pull request ScrapeGraphAI#2 from shkamboj1/pre/beta
Pre/beta
shkamboj1 authored May 4, 2024
2 parents d277b34 + fd59f28 commit d05093a
Showing 4 changed files with 78 additions and 1 deletion.
63 changes: 63 additions & 0 deletions examples/huggingfacehub/smart_scraper_huggingfacehub.py
@@ -0,0 +1,63 @@
"""
Basic example of scraping pipeline using SmartScraper using Azure OpenAI Key
"""

import os
from dotenv import load_dotenv
from scrapegraphai.graphs import SmartScraperGraph
from scrapegraphai.utils import prettify_exec_info
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings

# required environment variable in .env:
# HUGGINGFACEHUB_API_TOKEN
load_dotenv()

HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')

# ************************************************
# Initialize the model instances
# ************************************************

repo_id = "mistralai/Mistral-7B-Instruct-v0.2"

llm_model_instance = HuggingFaceEndpoint(
    repo_id=repo_id, max_length=128, temperature=0.5, token=HUGGINGFACEHUB_API_TOKEN
)

embedder_model_instance = HuggingFaceInferenceAPIEmbeddings(
    api_key=HUGGINGFACEHUB_API_TOKEN, model_name="sentence-transformers/all-MiniLM-l6-v2"
)

# ************************************************
# Create the SmartScraperGraph instance and run it
# ************************************************

graph_config = {
    "llm": {"model_instance": llm_model_instance},
    "embeddings": {"model_instance": embedder_model_instance}
}

smart_scraper_graph = SmartScraperGraph(
    prompt="List me all the events, with the following fields: company_name, event_name, event_start_date, event_start_time, event_end_date, event_end_time, location, event_mode, event_category, third_party_redirect, no_of_days, time_in_hours, hosted_or_attending, refreshments_type, registration_available, registration_link",
    # also accepts a string with the already downloaded HTML code
    source="https://www.hmhco.com/event",
    config=graph_config
)

result = smart_scraper_graph.run()
print(result)

# ************************************************
# Get graph execution info
# ************************************************

graph_exec_info = smart_scraper_graph.get_execution_info()
print(prettify_exec_info(graph_exec_info))
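
As the inline comment in the example notes, `source` can also be a string containing already-downloaded HTML instead of a URL. A minimal sketch of that variant, reusing the same graph_config (the HTML snippet and the shortened prompt below are hypothetical placeholders):

html_snippet = "<html><body><h1>Upcoming events</h1></body></html>"  # hypothetical placeholder HTML

smart_scraper_graph_local = SmartScraperGraph(
    prompt="List me all the events, with the following fields: company_name, event_name",
    source=html_snippet,  # already-downloaded HTML instead of a URL
    config=graph_config
)

print(smart_scraper_graph_local.run())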


11 changes: 10 additions & 1 deletion scrapegraphai/graphs/abstract_graph.py
@@ -69,6 +69,13 @@ def _set_model_token(self, llm):
                self.model_token = models_tokens["azure"][llm.model_name]
            except KeyError:
                raise KeyError("Model not supported")

        elif 'HuggingFaceEndpoint' in str(type(llm)):
            if 'mistral' in llm.repo_id:
                try:
                    self.model_token = models_tokens['mistral'][llm.repo_id]
                except KeyError:
                    raise KeyError("Model not supported")


    def _create_llm(self, llm_config: dict, chat=False) -> object:
@@ -181,7 +188,6 @@ def _create_default_embedder(self) -> object:
        Raises:
            ValueError: If the model is not supported.
        """

        if isinstance(self.llm_model, OpenAI):
            return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
        elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
@@ -216,6 +222,9 @@ def _create_embedder(self, embedder_config: dict) -> object:
        Raises:
            KeyError: If the model is not supported.
        """

        if 'model_instance' in embedder_config:
            return embedder_config['model_instance']

        # Instantiate the embedding model based on the model name
        if "openai" in embedder_config["model"]:
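
Taken together, the hunks above let a caller hand AbstractGraph prebuilt model objects: the new HuggingFaceEndpoint branch resolves the context window from models_tokens, and the early return in _create_embedder uses the supplied embedder as-is. A minimal sketch of the intended flow, assuming the config dict is threaded through unchanged (names taken from the example file above):

graph_config = {
    "llm": {"model_instance": llm_model_instance},              # prebuilt LLM, presumably picked up by _create_llm
    "embeddings": {"model_instance": embedder_model_instance},  # returned early by _create_embedder
}

# For a Mistral endpoint, _set_model_token resolves the context window via the lookup
# models_tokens['mistral'][llm.repo_id], e.g. 32000 for "mistralai/Mistral-7B-Instruct-v0.2".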
3 changes: 3 additions & 0 deletions scrapegraphai/helpers/models_tokens.py
@@ -65,5 +65,8 @@
        "mistral.mistral-large-2402-v1:0": 32768,
        "cohere.embed-english-v3": 512,
        "cohere.embed-multilingual-v3": 512
    },
    "mistral": {
        "mistralai/Mistral-7B-Instruct-v0.2": 32000
    }
}
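
For reference, a minimal sketch of how the new entry is consumed, mirroring the _set_model_token hunk above (the import path is an assumption):

from scrapegraphai.helpers.models_tokens import models_tokens  # assumed import path

repo_id = "mistralai/Mistral-7B-Instruct-v0.2"
try:
    context_window = models_tokens["mistral"][repo_id]  # 32000 for this repo id
except KeyError:
    raise KeyError("Model not supported")

print(context_window)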
2 changes: 2 additions & 0 deletions scrapegraphai/nodes/rag_node.py
Expand Up @@ -82,6 +82,8 @@ def execute(self, state: dict) -> dict:
if self.verbose:
print("--- (updated chunks metadata) ---")

# check if embedder_model is provided, if not use llm_model
self.embedder_model = self.embedder_model if self.embedder_model else self.llm_model
embeddings = self.embedder_model

retriever = FAISS.from_documents(
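
The fallback added above keeps the RAG step working when no dedicated embedder is configured; a standalone sketch of the same pattern, with hypothetical placeholder variables:

# Prefer the configured embedder; otherwise reuse the LLM instance for embeddings.
embedder_model = None                 # hypothetical: no embedder configured
llm_model = llm_model_instance        # hypothetical: the LLM from the example above
embeddings = embedder_model if embedder_model else llm_model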
