Commit d622eca

Merge pull request #244 from awslabs/chore/update-get-bedrock-client
chore: update bedrock client retrieval and aoss index creation
2 parents aacbc8a + 20694d9 commit d622eca

File tree: 11 files changed, +43 -199 lines

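Every file in this PR applies the same simplification: the per-module `get_bedrock_client()` helper (config lookup, optional STS assume-role, custom endpoint) is deleted and callers obtain the Bedrock runtime client straight from boto3. A minimal sketch of the resulting pattern, reusing only names that appear in the diffs below:

```python
import boto3
from langchain_community.embeddings import BedrockEmbeddings

# Credentials and region come from the default session (e.g. the Lambda
# execution role and AWS_REGION), so no bespoke STS/endpoint helper is needed.
bedrock_client = boto3.client("bedrock-runtime")

# The same client backs the LangChain wrappers used by the ingestion code.
embeddings = BedrockEmbeddings(client=bedrock_client,
                               model_id="amazon.titan-embed-text-v1")
```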

apidocs/classes/LangchainCommonLayer.md (+2 -2)

@@ -9,11 +9,11 @@ LangchainCommonLayer allows developers to instantiate a llm client adapter on be
 **`Example`**
 
 ```ts
+import boto3
 from genai_core.adapters.registry import registry
-from genai_core.clients import get_bedrock_client
 
 adapter = registry.get_adapter(f"{provider}.{model_id}")
-bedrock_client = get_bedrock_client()
+bedrock_client = boto3.client('bedrock-runtime')
 ```
 
 ## Hierarchy

lambda/aws-qa-appsync-opensearch/question_answering/src/llms/text_generation_llm_selector.py (+2 -36)

@@ -25,43 +25,9 @@
 tracer = Tracer(service="QUESTION_ANSWERING")
 metrics = Metrics(namespace="question_answering", service="QUESTION_ANSWERING")
 
-sts_client = boto3.client("sts")
-
-aws_region = boto3.Session().region_name
-
-def get_bedrock_client(service_name="bedrock-runtime"):
-    config = {}
-    bedrock_config = config.get("bedrock", {})
-    bedrock_enabled = bedrock_config.get("enabled", False)
-    if not bedrock_enabled:
-        print("bedrock not enabled")
-        return None
-
-    bedrock_config_data = {"service_name": service_name}
-    region_name = bedrock_config.get("region")
-    endpoint_url = bedrock_config.get("endpointUrl")
-    role_arn = bedrock_config.get("roleArn")
-
-    if region_name:
-        bedrock_config_data["region_name"] = region_name
-    if endpoint_url:
-        bedrock_config_data["endpoint_url"] = endpoint_url
-
-    if role_arn:
-        assumed_role_object = sts_client.assume_role(
-            RoleArn=role_arn,
-            RoleSessionName="AssumedRoleSession",
-        )
-
-        credentials = assumed_role_object["Credentials"]
-        bedrock_config_data["aws_access_key_id"] = credentials["AccessKeyId"]
-        bedrock_config_data["aws_secret_access_key"] = credentials["SecretAccessKey"]
-        bedrock_config_data["aws_session_token"] = credentials["SessionToken"]
-
-    return boto3.client(**bedrock_config_data)
 
 def get_llm(callbacks=None):
-    bedrock = get_bedrock_client(service_name="bedrock-runtime")
+    bedrock = boto3.client('bedrock-runtime')
 
     params = {
         "max_tokens_to_sample": 600,
@@ -85,7 +51,7 @@ def get_llm(callbacks=None):
     return Bedrock(**kwargs)
 
 def get_embeddings_llm():
-    bedrock = get_bedrock_client(service_name="bedrock-runtime")
+    bedrock = boto3.client('bedrock-runtime')
     return BedrockEmbeddings(client=bedrock, model_id="amazon.titan-embed-text-v1")
 
 def get_max_tokens():

lambda/aws-rag-appsync-stepfn-opensearch/embeddings_job/src/helpers/opensearch_helper.py (+3 -35)

@@ -23,39 +23,6 @@
 tracer = Tracer(service="INGESTION_EMBEDDING_JOB")
 metrics = Metrics(namespace="ingestion_pipeline", service="INGESTION_EMBEDDING_JOB")
 
-aws_region = boto3.Session().region_name
-sts_client = boto3.client("sts")
-
-def get_bedrock_client(service_name="bedrock-runtime"):
-    config = {}
-    bedrock_config = config.get("bedrock", {})
-    bedrock_enabled = bedrock_config.get("enabled", False)
-    if not bedrock_enabled:
-        print("bedrock not enabled")
-        return None
-
-    bedrock_config_data = {"service_name": service_name}
-    region_name = bedrock_config.get("region")
-    endpoint_url = bedrock_config.get("endpointUrl")
-    role_arn = bedrock_config.get("roleArn")
-
-    if region_name:
-        bedrock_config_data["region_name"] = region_name
-    if endpoint_url:
-        bedrock_config_data["endpoint_url"] = endpoint_url
-
-    if role_arn:
-        assumed_role_object = sts_client.assume_role(
-            RoleArn=role_arn,
-            RoleSessionName="AssumedRoleSession",
-        )
-
-        credentials = assumed_role_object["Credentials"]
-        bedrock_config_data["aws_access_key_id"] = credentials["AccessKeyId"]
-        bedrock_config_data["aws_secret_access_key"] = credentials["SecretAccessKey"]
-        bedrock_config_data["aws_session_token"] = credentials["SessionToken"]
-
-    return boto3.client(**bedrock_config_data)
 
 @tracer.capture_method
 def check_if_index_exists(index_name: str, region: str, host: str, http_auth: Tuple[str, str]) -> OpenSearch:
@@ -72,13 +39,14 @@ def check_if_index_exists(index_name: str, region: str, host: str, http_auth: Tu
 
 def process_shard(shard, os_index_name, os_domain_ep, os_http_auth) -> int:
     print(f'Starting process_shard of {len(shard)} chunks.')
-    bedrock_client = get_bedrock_client()
+    bedrock_client = boto3.client('bedrock-runtime')
     embeddings = BedrockEmbeddings(
         client=bedrock_client,
         model_id="amazon.titan-embed-text-v1")
+    opensearch_url = os_domain_ep if os_domain_ep.startswith("https://") else f"https://{os_domain_ep}"
     docsearch = OpenSearchVectorSearch(index_name=os_index_name,
         embedding_function=embeddings,
-        opensearch_url=f"https://{os_domain_ep}",
+        opensearch_url=opensearch_url,
         http_auth=os_http_auth,
         use_ssl = True,
         verify_certs = True,

lambda/aws-rag-appsync-stepfn-opensearch/embeddings_job/src/lambda.py (+25 -76)

@@ -22,7 +22,7 @@
 from helpers.update_ingestion_status import updateIngestionJobStatus
 from langchain_community.embeddings import BedrockEmbeddings
 from helpers.s3inmemoryloader import S3TxtFileLoaderInMemory
-from opensearchpy import OpenSearch, RequestsHttpConnection
+from opensearchpy import RequestsHttpConnection
 from langchain_community.vectorstores import OpenSearchVectorSearch
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 import multiprocessing as mp
@@ -39,38 +39,7 @@
 aws_region = boto3.Session().region_name
 session = boto3.session.Session()
 credentials = session.get_credentials()
-sts_client = boto3.client("sts")
 
-def get_bedrock_client(service_name="bedrock-runtime"):
-    config = {}
-    bedrock_config = config.get("bedrock", {})
-    bedrock_enabled = bedrock_config.get("enabled", False)
-    if not bedrock_enabled:
-        print("bedrock not enabled")
-        return None
-
-    bedrock_config_data = {"service_name": service_name}
-    region_name = bedrock_config.get("region")
-    endpoint_url = bedrock_config.get("endpointUrl")
-    role_arn = bedrock_config.get("roleArn")
-
-    if region_name:
-        bedrock_config_data["region_name"] = region_name
-    if endpoint_url:
-        bedrock_config_data["endpoint_url"] = endpoint_url
-
-    if role_arn:
-        assumed_role_object = sts_client.assume_role(
-            RoleArn=role_arn,
-            RoleSessionName="AssumedRoleSession",
-        )
-
-        credentials = assumed_role_object["Credentials"]
-        bedrock_config_data["aws_access_key_id"] = credentials["AccessKeyId"]
-        bedrock_config_data["aws_secret_access_key"] = credentials["SecretAccessKey"]
-        bedrock_config_data["aws_session_token"] = credentials["SessionToken"]
-
-    return boto3.client(**bedrock_config_data)
 
 opensearch_secret_id = os.environ['OPENSEARCH_SECRET_ID']
 bucket_name = os.environ['OUTPUT_BUCKET']
@@ -88,7 +57,7 @@ def get_bedrock_client(service_name="bedrock-runtime"):
 INDEX_FILE="index_file"
 
 def process_documents_in_es(index_exists, shards, http_auth):
-    bedrock_client = get_bedrock_client()
+    bedrock_client = boto3.client('bedrock-runtime')
     embeddings = BedrockEmbeddings(client=bedrock_client)
 
     if index_exists is False:
@@ -136,52 +105,32 @@ def process_documents_in_es(index_exists, shards, http_auth):
                             os_http_auth=http_auth)
 
 def process_documents_in_aoss(index_exists, shards, http_auth):
+    # Reference: https://python.langchain.com/docs/integrations/vectorstores/opensearch#using-aoss-amazon-opensearch-service-serverless
+    bedrock_client = boto3.client('bedrock-runtime')
+    embeddings = BedrockEmbeddings(client=bedrock_client)
+
+    shard_start_index = 0
     if index_exists is False:
-        vector_db = OpenSearch(
-            hosts = [{'host': opensearch_domain.replace("https://", ""), 'port': 443}],
-            http_auth = http_auth,
-            use_ssl = True,
-            verify_certs = True,
-            connection_class = RequestsHttpConnection
+        OpenSearchVectorSearch.from_documents(
+            shards[0],
+            embeddings,
+            opensearch_url=opensearch_domain,
+            http_auth=http_auth,
+            timeout=300,
+            use_ssl=True,
+            verify_certs=True,
+            connection_class=RequestsHttpConnection,
+            index_name=opensearch_index,
+            engine="faiss",
         )
-        index_body = {
-            'settings': {
-                "index.knn": True
-            },
-            "mappings": {
-                "properties": {
-                    "vector_field": {
-                        "type": "knn_vector",
-                        "dimension": 1536,
-                        "method": {
-                            "engine": "nmslib",
-                            "space_type": "cosinesimil",
-                            "name": "hnsw",
-                            "parameters": {},
-                        }
-                    },
-                    "id": {
-                        "type": "text",
-                        "fields": {"keyword": {"type": "keyword", "ignore_above": 256}},
-                    },
-                }
-            }
-        }
-        response = vector_db.indices.create(opensearch_index, body=index_body)
-        print(response)
+        # we now need to start the loop below for the second shard
+        shard_start_index = 1
 
-    print(f"index={opensearch_index} Adding Documents")
-    bedrock_client = get_bedrock_client()
-    embeddings = BedrockEmbeddings(client=bedrock_client, model_id="amazon.titan-embed-text-v1")
-    docsearch = OpenSearchVectorSearch(index_name=opensearch_index,
-                                       embedding_function=embeddings,
-                                       opensearch_url=opensearch_domain,
-                                       http_auth=http_auth,
-                                       use_ssl = True,
-                                       verify_certs = True,
-                                       connection_class = RequestsHttpConnection)
-    for shard in shards:
-        docsearch.add_documents(documents=shard)
+    for shard in shards[shard_start_index:]:
+        results = process_shard(shard=shard,
+                                os_index_name=opensearch_index,
+                                os_domain_ep=opensearch_domain,
+                                os_http_auth=http_auth)
 
 @logger.inject_lambda_context(log_event=True)
 @tracer.capture_lambda_handler
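For the OpenSearch Serverless path, index creation is now delegated to LangChain rather than built by hand with a knn_vector mapping. A minimal sketch of that bootstrap step, mirroring the `process_documents_in_aoss` change above; the helper name is hypothetical, and `shards`, `opensearch_domain`, `opensearch_index`, and `http_auth` are assumed to be prepared as in the Lambda:

```python
import boto3
from opensearchpy import RequestsHttpConnection
from langchain_community.embeddings import BedrockEmbeddings
from langchain_community.vectorstores import OpenSearchVectorSearch

def create_index_from_first_shard(shards, opensearch_domain, opensearch_index, http_auth):
    """Hypothetical helper sketching the bootstrap step in process_documents_in_aoss."""
    bedrock_client = boto3.client("bedrock-runtime")
    embeddings = BedrockEmbeddings(client=bedrock_client)

    # from_documents indexes the first shard and creates the index (including
    # its vector mapping) if it does not exist, which replaces the explicit
    # indices.create() call with the hand-written nmslib/hnsw body.
    OpenSearchVectorSearch.from_documents(
        shards[0],
        embeddings,
        opensearch_url=opensearch_domain,
        http_auth=http_auth,
        timeout=300,
        use_ssl=True,
        verify_certs=True,
        connection_class=RequestsHttpConnection,
        index_name=opensearch_index,
        engine="faiss",
    )
```

The remaining shards then go through the existing `process_shard()` helper, which now normalizes the endpoint URL before constructing `OpenSearchVectorSearch`.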

lambda/aws-summarization-appsync-stepfn/summary_generator/lambda.py (+1 -7)

@@ -37,8 +37,6 @@
 transformed_bucket_name = os.environ["ASSET_BUCKET_NAME"]
 chain_type = os.environ["SUMMARY_LLM_CHAIN_TYPE"]
 
-aws_region = boto3.Session().region_name
-
 params = {
     "max_tokens_to_sample": 4000,
     "temperature": 0,
@@ -47,11 +45,7 @@
     "stop_sequences": ["\\n\\nHuman:"],
 }
 
-bedrock_client = boto3.client(
-    service_name='bedrock-runtime',
-    region_name=aws_region,
-    endpoint_url=f'https://bedrock-runtime.{aws_region}.amazonaws.com'
-)
+bedrock_client = boto3.client('bedrock-runtime')
 
 @logger.inject_lambda_context(log_event=True)
 @tracer.capture_lambda_handler

layers/langchain-common-layer/python/genai_core/adapters/bedrock/ai21_j2.py (+2 -2)

@@ -10,7 +10,7 @@
 # OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
 # and limitations under the License.
 #
-import genai_core.clients
+import boto3
 from langchain.llms import Bedrock
 from langchain.prompts.prompt import PromptTemplate
 
@@ -25,7 +25,7 @@ def __init__(self, model_id, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
     def get_llm(self, model_kwargs={}):
-        bedrock = genai_core.clients.get_bedrock_client()
+        bedrock = boto3.client('bedrock-runtime')
 
         params = {}
         if "temperature" in model_kwargs:

layers/langchain-common-layer/python/genai_core/adapters/bedrock/claude.py (+2 -2)

@@ -10,7 +10,7 @@
 # OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
 # and limitations under the License.
 #
-import genai_core.clients
+import boto3
 
 from langchain.llms import Bedrock
 from langchain.prompts.prompt import PromptTemplate
@@ -26,7 +26,7 @@ def __init__(self, model_id, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
     def get_llm(self, model_kwargs={}):
-        bedrock = genai_core.clients.get_bedrock_client()
+        bedrock = boto3.client('bedrock-runtime')
 
         params = {}
         if "temperature" in model_kwargs:
if "temperature" in model_kwargs:

layers/langchain-common-layer/python/genai_core/adapters/bedrock/cohere.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
1111
# and limitations under the License.
1212
#
13-
import genai_core.clients
13+
import boto3
1414

1515
from langchain.llms import Bedrock
1616
from langchain.prompts.prompt import PromptTemplate
@@ -26,7 +26,7 @@ def __init__(self, model_id, *args, **kwargs):
2626
super().__init__(*args, **kwargs)
2727

2828
def get_llm(self, model_kwargs={}):
29-
bedrock = genai_core.clients.get_bedrock_client()
29+
bedrock = boto3.client('bedrock-runtime')
3030

3131
params = {}
3232
if "temperature" in model_kwargs:

layers/langchain-common-layer/python/genai_core/adapters/bedrock/titan.py (+2 -2)

@@ -10,7 +10,7 @@
 # OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
 # and limitations under the License.
 #
-import genai_core.clients
+import boto3
 from langchain.prompts.prompt import PromptTemplate
 
 from langchain.llms import Bedrock
@@ -26,7 +26,7 @@ def __init__(self, model_id, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
     def get_llm(self, model_kwargs={}):
-        bedrock = genai_core.clients.get_bedrock_client()
+        bedrock = boto3.client('bedrock-runtime')
 
         params = {}
         if "temperature" in model_kwargs:

layers/langchain-common-layer/python/genai_core/clients.py (-33)

@@ -15,8 +15,6 @@
 import openai
 from botocore.config import Config
 
-sts_client = boto3.client("sts")
-
 def get_openai_client():
     api_key = os.environ['OPEN_API_KEY']
     if not api_key:
@@ -33,34 +31,3 @@ def get_sagemaker_client():
     client = boto3.client("sagemaker-runtime", config=config)
 
     return client
-
-
-def get_bedrock_client(service_name="bedrock-runtime"):
-    config = {}
-    bedrock_config = config.get("bedrock", {})
-    bedrock_enabled = bedrock_config.get("enabled", False)
-    if not bedrock_enabled:
-        return None
-
-    bedrock_config_data = {"service_name": service_name}
-    region_name = bedrock_config.get("region")
-    endpoint_url = bedrock_config.get("endpointUrl")
-    role_arn = bedrock_config.get("roleArn")
-
-    if region_name:
-        bedrock_config_data["region_name"] = region_name
-    if endpoint_url:
-        bedrock_config_data["endpoint_url"] = endpoint_url
-
-    if role_arn:
-        assumed_role_object = sts_client.assume_role(
-            RoleArn=role_arn,
-            RoleSessionName="AssumedRoleSession",
-        )
-
-        credentials = assumed_role_object["Credentials"]
-        bedrock_config_data["aws_access_key_id"] = credentials["AccessKeyId"]
-        bedrock_config_data["aws_secret_access_key"] = credentials["SecretAccessKey"]
-        bedrock_config_data["aws_session_token"] = credentials["SessionToken"]
-
-    return boto3.client(**bedrock_config_data)
