Fix: Qdrant query count not optional (#972)
cjkindel authored Jul 13, 2024
1 parent b8f1128 commit 0a8e178
Showing 3 changed files with 30 additions and 71 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

### Fixed
- Parameter `count` for `QdrantVectorStoreDriver.query` is now optional, as per the documentation.

## [0.28.2] - 2024-07-12
### Fixed
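As a minimal illustration of the changelog entry above (assuming a `QdrantVectorStoreDriver` instance already configured as `vector_store_driver`, as in the documentation change below), both call forms are now valid:

```python
# Sketch only: `vector_store_driver` is assumed to be a configured QdrantVectorStoreDriver.
# With `count` omitted, the driver no longer forwards limit=None to Qdrant,
# so the backend's default result limit applies.
results_default = vector_store_driver.query(query="What is griptape?")

# An explicit count still works as before.
results_top_three = vector_store_driver.query(query="What is griptape?", count=3)
```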
96 changes: 26 additions & 70 deletions docs/griptape-framework/drivers/vector-store-drivers.md
@@ -24,7 +24,6 @@ The [LocalVectorStoreDriver](../../reference/griptape/drivers/vector/local_vecto

```python
import os
from griptape.artifacts import BaseArtifact
from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver
from griptape.loaders import WebLoader

@@ -40,16 +39,11 @@ artifacts = WebLoader(max_tokens=100).load("https://www.griptape.ai")
# Upsert Artifacts into the Vector Store Driver
[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts]

results = vector_store_driver.query(
"creativity",
count=3,
namespace="griptape"
)
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

print("\n\n".join(values))

```

### Griptape Cloud Knowledge Base
@@ -58,7 +52,6 @@ The [GriptapeCloudKnowledgeBaseVectorStoreDriver](../../reference/griptape/drive

```python
import os
from griptape.artifacts import BaseArtifact
from griptape.drivers import GriptapeCloudKnowledgeBaseVectorStoreDriver


@@ -68,12 +61,11 @@ gt_cloud_knowledge_base_id = os.environ["GRIPTAPE_CLOUD_KB_ID"]

vector_store_driver = GriptapeCloudKnowledgeBaseVectorStoreDriver(api_key=gt_cloud_api_key, knowledge_base_id=gt_cloud_knowledge_base_id)

results =vector_store_driver.query(query="What is griptape?")
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

print("\n\n".join(values))

```

### Pinecone
@@ -86,50 +78,28 @@ The [PineconeVectorStoreDriver](../../reference/griptape/drivers/vector/pinecone
Here is an example of how the Driver can be used to load and query information in a Pinecone cluster:

```python
import os
import hashlib
import json
from urllib.request import urlopen
import os
from griptape.drivers import PineconeVectorStoreDriver, OpenAiEmbeddingDriver
from griptape.loaders import WebLoader

def load_data(driver: PineconeVectorStoreDriver) -> None:
response = urlopen(
"https://raw.githubusercontent.com/wedeploy-examples/"
"supermarket-web-example/master/products.json"
)

for product in json.loads(response.read()):
driver.upsert_text(
product["description"],
vector_id=hashlib.md5(product["title"].encode()).hexdigest(),
meta={
"title": product["title"],
"description": product["description"],
"type": product["type"],
"price": product["price"],
"rating": product["rating"],
},
namespace="supermarket-products",
)

# Initialize an Embedding Driver
embedding_driver = OpenAiEmbeddingDriver(api_key=os.environ["OPENAI_API_KEY"])

vector_store_driver = PineconeVectorStoreDriver(
api_key=os.environ["PINECONE_API_KEY"],
environment=os.environ["PINECONE_ENVIRONMENT"],
index_name=os.environ['PINECONE_INDEX_NAME'],
index_name=os.environ["PINECONE_INDEX_NAME"],
embedding_driver=embedding_driver,
)

load_data(vector_store_driver)
# Load Artifacts from the web
artifacts = WebLoader(max_tokens=100).load("https://www.griptape.ai")

results = vector_store_driver.query(
"fruit",
count=3,
filter={"price": {"$lte": 15}, "rating": {"$gte": 4}},
namespace="supermarket-products",
)
# Upsert Artifacts into the Vector Store Driver
[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts]

results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

@@ -175,7 +145,7 @@ vector_store_driver.upsert_text_artifacts(
}
)

results =vector_store_driver.query(query="What is griptape?")
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

@@ -227,7 +197,7 @@ vector_store_driver.upsert_text_artifacts(
}
)

results =vector_store_driver.query(query="What is griptape?")
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

@@ -298,7 +268,7 @@ vector_store_driver.upsert_text_artifacts(
}
)

results =vector_store_driver.query(query="What is griptape?")
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

@@ -341,7 +311,7 @@ vector_store_driver.upsert_text_artifacts(
}
)

results =vector_store_driver.query(query="What is griptape?")
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

@@ -388,7 +358,7 @@ vector_store_driver.upsert_text_artifacts(
}
)

results =vector_store_driver.query(query="What is griptape?")
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

@@ -450,7 +420,7 @@ vector_store_driver.upsert_text_artifacts(
}
)

results =vector_store_driver.query(query="What is griptape?")
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

@@ -468,54 +438,40 @@ Here is an example of how the Driver can be used to query information in a Qdran

```python
import os
from griptape.drivers import QdrantVectorStoreDriver, HuggingFaceHubEmbeddingDriver
from griptape.tokenizers import HuggingFaceTokenizer
from griptape.drivers import QdrantVectorStoreDriver, OpenAiEmbeddingDriver
from griptape.loaders import WebLoader

# Set up environment variables
embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
host = os.environ["QDRANT_CLUSTER_ENDPOINT"]
huggingface_token = os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"]
api_key = os.environ["QDRANT_CLUSTER_API_KEY"]

# Initialize HuggingFace Embedding Driver
embedding_driver = HuggingFaceHubEmbeddingDriver(
api_token=huggingface_token,
model=embedding_model_name,
tokenizer=HuggingFaceTokenizer(model=embedding_model_name, max_output_tokens=512),
)
# Initialize an Embedding Driver.
embedding_driver = OpenAiEmbeddingDriver(api_key=os.environ["OPENAI_API_KEY"])

# Initialize Qdrant Vector Store Driver
vector_store_driver = QdrantVectorStoreDriver(
url=host,
collection_name="griptape",
content_payload_key="content",
embedding_driver=embedding_driver,
api_key=os.environ["QDRANT_CLUSTER_API_KEY"],
api_key=api_key,
)

# Load Artifacts from the web
artifacts = WebLoader().load("https://www.griptape.ai")

# Encode text to get embeddings
embeddings = embedding_driver.embed_text_artifact(artifacts[0])

# Recreate Qdrant collection
vector_store_driver.client.recreate_collection(
collection_name=vector_store_driver.collection_name,
vectors_config={
"size": len(embeddings),
"size": 1536,
"distance": vector_store_driver.distance
},
)

# Upsert vector into Qdrant
vector_store_driver.upsert_vector(
vector=embeddings,
vector_id=str(artifacts[0].id),
content=artifacts[0].value
)
# Upsert Artifacts into the Vector Store Driver
[vector_store_driver.upsert_text_artifact(a, namespace="griptape") for a in artifacts]

results =vector_store_driver.query(query="What is griptape?")
results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

4 changes: 3 additions & 1 deletion griptape/drivers/vector/qdrant_vector_store_driver.py
@@ -106,7 +106,9 @@ def query(
query_vector = self.embedding_driver.embed_string(query)

# Create a search request
results = self.client.search(collection_name=self.collection_name, query_vector=query_vector, limit=count)
request = {"collection_name": self.collection_name, "query_vector": query_vector, "limit": count}
request = {k: v for k, v in request.items() if v is not None}
results = self.client.search(**request)

# Convert results to QueryResult objects
query_results = [
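For context, the fix above assembles the search request as a plain dict and strips any `None`-valued entries before calling `client.search`, so an omitted `count` falls back to the Qdrant client's default limit instead of being passed as `limit=None`. A self-contained sketch of the same pattern, using illustrative names that are not from the Griptape codebase:

```python
from typing import Any, Optional


def build_search_request(
    collection_name: str,
    query_vector: list[float],
    limit: Optional[int] = None,
) -> dict[str, Any]:
    # Assemble every candidate argument, then drop the ones left as None
    # so the client library's own defaults apply instead.
    request = {"collection_name": collection_name, "query_vector": query_vector, "limit": limit}
    return {k: v for k, v in request.items() if v is not None}


# With limit omitted, the "limit" key is stripped entirely.
print(build_search_request("griptape", [0.1, 0.2]))
# -> {'collection_name': 'griptape', 'query_vector': [0.1, 0.2]}
```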
