Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Binary file not shown.
156,031 changes: 156,031 additions & 0 deletions G-DeepLearning.AI/Q-Text-Embedding-GoogleCloud/L5-Applications/so_database_app.csv

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import os
from dotenv import load_dotenv
import json
import base64
from google.auth.transport.requests import Request
from google.oauth2.service_account import Credentials
import functools
import time
from concurrent.futures import ThreadPoolExecutor
from tqdm.auto import tqdm
import math
from vertexai.language_models import TextEmbeddingModel
import numpy as np
import matplotlib.pyplot as plt
import mplcursors

def authenticate():
    """Return a ``(credentials, project_id)`` pair for Google Cloud.

    Environment variables are first loaded from a ``.env`` file. If
    ``SERVICE_ACCOUNT_KEY`` (a base64-encoded service-account JSON key) is
    set, service-account credentials with the ``cloud-platform`` scope are
    built from it and returned together with the value of the ``PROJECT_ID``
    environment variable (``None`` if that variable is unset).

    If ``SERVICE_ACCOUNT_KEY`` is not set, the placeholder pair
    ``("DLAI-credentials", "DLAI-PROJECT")`` is returned. Previously this
    placeholder was returned unconditionally by an early ``return``, which
    left the credential-loading code below it unreachable.

    Returns:
        tuple: ``(credentials, project_id)``.

    Raises:
        ValueError: If the key is not valid base64 or not valid JSON
            (``binascii.Error`` and ``json.JSONDecodeError`` are both
            subclasses of ``ValueError``).
    """
    # Load .env
    load_dotenv()

    encoded_key = os.getenv('SERVICE_ACCOUNT_KEY')
    if not encoded_key:
        # No key configured: keep the historical placeholder result instead
        # of crashing on ``None.encode(...)``.
        return "DLAI-credentials", "DLAI-PROJECT"

    # Decode the base64 key and parse it as JSON (kept in memory only).
    key_json = base64.b64decode(encoded_key.encode("ascii")).decode("ascii")
    service_account_key = json.loads(key_json)

    # Create credentials based on key from service account
    # Make sure your account has the roles listed in the Google Cloud Setup section
    credentials = Credentials.from_service_account_info(
        service_account_key,
        scopes=['https://www.googleapis.com/auth/cloud-platform'])

    # Refresh only when the credentials report themselves as expired.
    if credentials.expired:
        credentials.refresh(Request())

    # Set project ID according to environment variable
    project_id = os.getenv('PROJECT_ID')

    return credentials, project_id


def generate_batches(sentences, batch_size = 5):
    """Yield consecutive slices of ``sentences`` of length ``batch_size``.

    The last slice is shorter when ``len(sentences)`` is not a multiple of
    ``batch_size``. An empty input yields nothing.

    Args:
        sentences: A sized, sliceable sequence (list, string, array, ...).
        batch_size: Maximum number of items per batch; must be at least 1.

    Yields:
        Slices ``sentences[i : i + batch_size]`` in order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1. Raised when the
            generator is first advanced, not when it is created.
    """
    # A negative step would make range() empty and silently drop all input;
    # zero would raise an unrelated-looking range() error. Fail clearly.
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    for start in range(0, len(sentences), batch_size):
        yield sentences[start : start + batch_size]

def encode_texts_to_embeddings(sentences):
    """Embed a batch of sentences with ``textembedding-gecko@001``.

    Args:
        sentences: The texts to embed in a single ``get_embeddings`` call.

    Returns:
        list: One entry per input sentence. On success each entry is the
        ``values`` attribute of the corresponding embedding. If the
        embedding call raises, the failure is logged and a list of ``None``
        with the same length as ``sentences`` is returned instead, so the
        caller can keep going (best effort).
    """
    import logging

    # Loading the model is outside the try: a failure here is not swallowed.
    model = TextEmbeddingModel.from_pretrained("textembedding-gecko@001")
    try:
        embeddings = model.get_embeddings(sentences)
        return [embedding.values for embedding in embeddings]
    except Exception:
        # Still best effort, but no longer silent: record why the batch failed.
        logging.getLogger(__name__).warning(
            "Embedding request failed for a batch of %d sentence(s)",
            len(sentences),
            exc_info=True,
        )
        return [None] * len(sentences)

def encode_text_to_embedding_batched(sentences, api_calls_per_second = 0.33, batch_size = 5):
    """Embed ``sentences`` in rate-limited batches and stack the results.

    Batches come from ``generate_batches`` and are submitted to a thread
    pool, sleeping ``1 / api_calls_per_second`` seconds after each
    submission. Each batch is embedded by ``encode_texts_to_embeddings``.
    Entries that come back as ``None`` (failed batches) are dropped, so the
    returned rows do not necessarily line up one-to-one with ``sentences``.

    Args:
        sentences: Sized, sliceable collection of texts.
        api_calls_per_second: Maximum submission rate; must be positive.
        batch_size: Number of sentences per API call; must be at least 1.

    Returns:
        numpy.ndarray: The successful embeddings stacked and then passed
        through ``np.squeeze`` (so a single successful embedding comes back
        as a 1-D array). If ``sentences`` is empty or every batch failed, an
        empty array of shape ``(0, 0)`` is returned instead of raising.

    Raises:
        ValueError: If ``api_calls_per_second`` is not positive or
            ``batch_size`` is smaller than 1.
    """
    if api_calls_per_second <= 0:
        raise ValueError(
            f"api_calls_per_second must be positive, got {api_calls_per_second}"
        )
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    # Nothing to embed: np.stack([]) would raise, so answer directly.
    if len(sentences) == 0:
        return np.empty((0, 0))

    embeddings_list = []

    # Prepare the batches using a generator
    batches = generate_batches(sentences, batch_size)

    seconds_per_job = 1 / api_calls_per_second

    with ThreadPoolExecutor() as executor:
        futures = []
        for batch in tqdm(
            batches, total = math.ceil(len(sentences) / batch_size), position=0
        ):
            futures.append(executor.submit(encode_texts_to_embeddings, batch))
            # Throttle submissions to respect the API rate limit.
            time.sleep(seconds_per_job)

        # Collect in submission order so results follow the input order.
        for future in futures:
            embeddings_list.extend(future.result())

    successful = [
        embedding for embedding in embeddings_list if embedding is not None
    ]
    if not successful:
        # Every batch failed; avoid np.stack's error on an empty list.
        return np.empty((0, 0))

    return np.squeeze(np.stack(successful))

def clusters_2D(x_values, y_values, labels, kmeans_labels):
    """Display an interactive 2-D scatter plot of clustered points.

    Points are drawn at ``(x_values, y_values)`` and coloured by
    ``kmeans_labels``. Hovering over a point shows an annotation whose text
    is ``labels.category`` looked up at the hovered point's index.

    Args:
        x_values: Horizontal coordinates of the points.
        y_values: Vertical coordinates of the points.
        labels: Object exposing a ``category`` attribute that can be indexed
            by point index (presumably a DataFrame column; verify against
            the caller).
        kmeans_labels: Per-point values used for colouring.
    """
    _, axes = plt.subplots()

    # Title and axis labels
    axes.set_title('Embedding clusters visualization in 2D')
    axes.set_xlabel('X_1')
    axes.set_ylabel('X_2')

    marker_style = dict(cmap='Set1', alpha=0.5, edgecolors='k', s=40)
    points = axes.scatter(x_values, y_values, c=kmeans_labels, **marker_style)

    # The cursor object drives the hover interaction on the scatter points
    hover_cursor = mplcursors.cursor(points, hover=True)

    # Style and fill each annotation as it is added
    @hover_cursor.connect("add")
    def show_category(sel):
        annotation = sel.annotation
        annotation.set_text(labels.category[sel.target.index])
        # White, almost opaque background behind the annotation text
        annotation.get_bbox_patch().set(facecolor='white', alpha=0.95)
        annotation.set_fontsize(14)

    plt.show()
Loading