
Commit b3f2e7a

chore(qa): refactor text qa to add support for additional models and fix issues (#332)
* chore(qa): refactor text qa to add support for additional models and fix issues
* chore(rag): fix type in graphql schema preventing deployment if qa used in same stack
* chore(graphql): fix space in graphql schema
* chore(clean): remove useless parts
* chore(qa): fix filtering and move logic
* chore(debug): test image qa fix logging
1 parent 8ce1242 commit b3f2e7a

File tree

19 files changed: +531 −134 lines changed

@@ -0,0 +1 @@
+from .bedrock import *
@@ -0,0 +1 @@
+from .base import ModelAdapter
@@ -0,0 +1,45 @@
+#
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+# with the License. A copy of the License is located at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES
+# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
+# and limitations under the License.
+#
+import os
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.prompts.prompt import PromptTemplate
+
+
+class ModelAdapter:
+    def __init__(self, callback=None, modality='Text', model_kwargs={}):
+        self.model_kwargs = model_kwargs
+        self.modality = modality
+
+        self.callback_handler = callback
+
+        self.llm = self.get_llm(model_kwargs)
+
+    def get_llm(self, model_kwargs={}):
+        raise ValueError("llm must be implemented")
+
+    def get_embeddings_model(self, model_kwargs={}):
+        raise ValueError("embeddings must be implemented")
+
+    def get_prompt(self):
+
+        template = """
+
+Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
+
+{context}
+
+Question: {question}"""
+
+        prompt_template = PromptTemplate(template=template, input_variables=["context", "question"])
+
+        return prompt_template
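For context on how this base class is meant to be consumed (a minimal sketch, not part of the diff): a concrete adapter overrides get_llm, and the constructor wires the result onto self.llm. EchoAdapter below is a made-up stand-in, assuming the ModelAdapter class above is in scope.

from langchain.prompts.prompt import PromptTemplate

# Hypothetical subclass for illustration only; a real adapter returns a
# LangChain LLM bound to an actual model (see the Bedrock adapters below).
class EchoAdapter(ModelAdapter):
    def get_llm(self, model_kwargs={}):
        # Stand-in "LLM": anything the calling QA chain knows how to invoke.
        return lambda prompt: prompt

    def get_prompt(self):
        # Adapters may override the default QA prompt with a model-specific one.
        return PromptTemplate(
            template="Context: {context}\n\nQuestion: {question}\nAnswer:",
            input_variables=["context", "question"],
        )

adapter = EchoAdapter(modality='Text')
print(adapter.llm("hello"))  # -> "hello"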
@@ -0,0 +1,2 @@
+from .claude import *
+from .titan import *
@@ -0,0 +1,134 @@
+#
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+# with the License. A copy of the License is located at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES
+# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
+# and limitations under the License.
+#
+import boto3
+
+from langchain_community.llms import Bedrock
+from langchain_community.chat_models import BedrockChat
+from langchain.prompts.prompt import PromptTemplate
+
+from ..base import ModelAdapter
+from ..registry import registry
+
+
+class BedrockClaudeAdapter(ModelAdapter):
+    def __init__(self, model_id, *args, **kwargs):
+        self.model_id = model_id
+
+        super().__init__(*args, **kwargs)
+
+    def get_llm(self, model_kwargs={}):
+        bedrock = boto3.client('bedrock-runtime')
+
+        params = {}
+        if "temperature" in model_kwargs:
+            params["temperature"] = model_kwargs["temperature"]
+        if "top_p" in model_kwargs:
+            params["top_p"] = model_kwargs["top_p"]
+        if "max_tokens_to_sample" in model_kwargs:
+            params["max_tokens_to_sample"] = model_kwargs["max_tokens_to_sample"]
+        if "stop_sequences" in model_kwargs:
+            params["stop_sequences"] = model_kwargs["stop_sequences"]
+        if "top_k" in model_kwargs:
+            params["top_k"] = model_kwargs["top_k"]
+
+        params["anthropic_version"] = "bedrock-2023-05-31"
+
+        kwargs = {
+            "client": bedrock,
+            "model_id": self.model_id,
+            "model_kwargs": params,
+            "streaming": False
+        }
+
+        if self.callback_handler:
+            kwargs["callbacks"] = self.callback_handler
+            kwargs["streaming"] = model_kwargs.get("streaming", False)
+
+        return Bedrock(
+            **kwargs
+        )
+
+    def get_prompt(self):
+        template = """
+
+Human: Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
+
+{context}
+
+Question: {question}
+
+Assistant:"""
+
+        return PromptTemplate(
+            template=template, input_variables=["context", "question"]
+        )
+
+# For claude v3, at the moment we need to use BedrockChat
+class BedrockClaudev3Adapter(ModelAdapter):
+    def __init__(self, model_id, *args, **kwargs):
+        self.model_id = model_id
+
+        super().__init__(*args, **kwargs)
+
+    def get_llm(self, model_kwargs={}):
+        bedrock = boto3.client('bedrock-runtime')
+
+        params = {}
+        if "temperature" in model_kwargs:
+            params["temperature"] = model_kwargs["temperature"]
+        if "top_p" in model_kwargs:
+            params["top_p"] = model_kwargs["top_p"]
+        if "max_tokens" in model_kwargs:
+            params["max_tokens"] = model_kwargs["max_tokens"]
+        if "stop_sequences" in model_kwargs:
+            params["stop_sequences"] = model_kwargs["stop_sequences"]
+        if "top_k" in model_kwargs:
+            params["top_k"] = model_kwargs["top_k"]
+
+        params["anthropic_version"] = "bedrock-2023-05-31"
+
+        kwargs = {
+            "client": bedrock,
+            "model_id": self.model_id,
+            "model_kwargs": params,
+            "streaming": False
+        }
+
+        if self.callback_handler:
+            kwargs["callbacks"] = self.callback_handler
+            kwargs["streaming"] = model_kwargs.get("streaming", False)
+
+        return BedrockChat(
+            **kwargs
+        )
+
+    def get_prompt(self):
+        template = """
+
+Human: Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
+
+{context}
+
+Question: {question}
+
+Assistant:"""
+
+        return PromptTemplate(
+            template=template, input_variables=["context", "question"]
+        )
+
+
+# Register the adapter
+registry.register(r"^Bedrock.anthropic.claude-v2*", BedrockClaudeAdapter)
+registry.register(r"^Bedrock.anthropic.claude-instant*", BedrockClaudeAdapter)
+registry.register(r"^Bedrock.anthropic.claude-3*", BedrockClaudev3Adapter)
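A hedged sketch of how these registrations are typically consumed (the real call site lives elsewhere in the QA lambda and is not part of this diff; the key format below is an assumption inferred from the registered patterns):

# Assumption: the resolver passes keys shaped like "Bedrock.<provider>.<model-id>".
adapter_class = registry.get_adapter("Bedrock.anthropic.claude-3-sonnet-20240229-v1:0")
if adapter_class is not None:
    adapter = adapter_class(
        model_id="anthropic.claude-3-sonnet-20240229-v1:0",
        model_kwargs={"temperature": 0, "max_tokens": 512},
    )
    llm = adapter.llm              # a BedrockChat instance for Claude 3
    prompt = adapter.get_prompt()  # the Human/Assistant QA template above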
@@ -0,0 +1,77 @@
+#
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+# with the License. A copy of the License is located at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES
+# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
+# and limitations under the License.
+#
+import boto3
+from langchain.prompts.prompt import PromptTemplate
+
+from langchain_community.llms import Bedrock
+from langchain_community.embeddings import BedrockEmbeddings
+
+from ..base import ModelAdapter
+from ..registry import registry
+
+
+class BedrockTitanAdapter(ModelAdapter):
+    def __init__(self, model_id, *args, **kwargs):
+        self.model_id = model_id
+
+        super().__init__(*args, **kwargs)
+
+    def get_llm(self, model_kwargs={}):
+        bedrock = boto3.client('bedrock-runtime')
+
+        params = {}
+        if "temperature" in model_kwargs:
+            params["temperature"] = model_kwargs["temperature"]
+        if "topP" in model_kwargs:
+            params["topP"] = model_kwargs["topP"]
+        if "maxTokenCount" in model_kwargs:
+            params["maxTokenCount"] = model_kwargs["maxTokenCount"]
+        if "stopSequences" in model_kwargs:
+            params["stopSequences"] = model_kwargs["stopSequences"]
+
+        kwargs = {
+            "client": bedrock,
+            "model_id": self.model_id,
+            "model_kwargs": params,
+            "streaming": False
+        }
+
+        if self.callback_handler:
+            kwargs["callbacks"] = self.callback_handler
+            kwargs["streaming"] = model_kwargs.get("streaming", False)
+
+        return Bedrock(
+            **kwargs
+        )
+
+    def get_embeddings_model(self, model_kwargs={}):
+        bedrock = boto3.client('bedrock-runtime')
+
+        return BedrockEmbeddings(client=bedrock, model_id=self.model_id)
+
+    def get_prompt(self):
+        template = """Human: The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.
+
+{context}
+
+Question: {question}
+
+Assistant:"""
+
+        return PromptTemplate(
+            template=template, input_variables=["context", "question"]
+        )
+
+# Register the adapter
+registry.register(r"^Bedrock.amazon.titan-t*", BedrockTitanAdapter)
+registry.register(r"^Bedrock.amazon.titan-e*", BedrockTitanAdapter)
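Because the titan-e* pattern also maps to this adapter, the embeddings path can be exercised the same way. A sketch, assuming "amazon.titan-embed-text-v1" is the model the caller passes in (note that instantiation also builds a text LLM via the base constructor):

# Assumption: "amazon.titan-embed-text-v1" is the embeddings model in use.
adapter_class = registry.get_adapter("Bedrock.amazon.titan-embed-text-v1")
adapter = adapter_class(model_id="amazon.titan-embed-text-v1")
embeddings = adapter.get_embeddings_model()
vector = embeddings.embed_query("What does this repository do?")  # list[float]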
@@ -0,0 +1,3 @@
+from .index import AdapterRegistry
+
+registry = AdapterRegistry()
@@ -0,0 +1,33 @@
+#
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+# with the License. A copy of the License is located at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES
+# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
+# and limitations under the License.
+#
+import re
+
+class AdapterRegistry:
+    def __init__(self):
+        # The registry is a dictionary where:
+        #   Keys are compiled regular expressions
+        #   Values are adapter classes
+        self.registry = {}
+
+    def register(self, regex, adapter):
+        # Compiles the regex and stores it in the registry
+        self.registry[re.compile(regex)] = adapter
+
+    def get_adapter(self, model):
+        # Iterates over the registered regexes
+        for regex, adapter in self.registry.items():
+            # If a match is found, returns the associated adapter class
+            if regex.match(model):
+                return adapter
+        # If no match is found, returns None
+        return None
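A small sketch of the lookup semantics (the string values below stand in for adapter classes; re.match anchors at the start of the key, and the first registered pattern that matches wins because dicts preserve insertion order):

r = AdapterRegistry()
r.register(r"^Bedrock.anthropic.claude-3*", "claude3")  # stand-in values
r.register(r"^Bedrock.amazon.titan-t*", "titan")

assert r.get_adapter("Bedrock.anthropic.claude-3-haiku-20240307-v1:0") == "claude3"
assert r.get_adapter("Bedrock.ai21.j2-ultra") is None   # no pattern matches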
@@ -1 +1 @@
-from .text_generation_llm_selector import get_llm, get_max_tokens, get_embeddings_llm,get_bedrock_fm
+from .text_generation_llm_selector import get_max_tokens, get_bedrock_fm

lambda/aws-qa-appsync-opensearch/question_answering/src/llms/text_generation_llm_selector.py (+13 −47)
@@ -11,11 +11,12 @@
 # and limitations under the License.
 #
 from aiohttp import ClientError
-from langchain.llms.bedrock import Bedrock
+from langchain_community.llms import Bedrock
 from langchain_community.embeddings import BedrockEmbeddings
 import os
 import boto3
 from .helper import get_credentials
+from .types import Provider, BedrockModel, MAX_TOKENS_MAP
 
 from aws_lambda_powertools import Logger, Tracer, Metrics
 from aws_lambda_powertools.utilities.typing import LambdaContext
@@ -26,41 +27,6 @@
 tracer = Tracer(service="QUESTION_ANSWERING")
 metrics = Metrics(namespace="question_answering", service="QUESTION_ANSWERING")
 
-
-
-def get_llm(callbacks=None,model_id="anthropic.claude-v2:1"):
-    bedrock = boto3.client('bedrock-runtime')
-
-    params = {
-        "max_tokens_to_sample": 600,
-        "temperature": 0,
-        "top_k": 250,
-        "top_p": 1,
-        "stop_sequences": ["\\n\\nHuman:"],
-    }
-
-    kwargs = {
-        "client": bedrock,
-        "model_id": model_id,
-        "model_kwargs": params,
-        "streaming": False
-    }
-
-    if callbacks:
-        kwargs["callbacks"] = callbacks
-        kwargs["streaming"] = True
-
-    return Bedrock(**kwargs)
-
-def get_embeddings_llm(model_id,modality):
-    bedrock = boto3.client('bedrock-runtime')
-    validation_status=validate_model_id_in_bedrock(model_id,modality)
-    if(validation_status['status']):
-        return BedrockEmbeddings(client=bedrock, model_id=model_id)
-    else:
-        return None
-
-
 def get_bedrock_fm(model_id,modality):
     bedrock_client = boto3.client('bedrock-runtime')
     validation_status= validate_model_id_in_bedrock(model_id,modality)
@@ -71,17 +37,17 @@ def get_bedrock_fm(model_id,modality):
         logger.error(f"reason ::{validation_status['message']} ")
         return None
 
-
-
-#TODO -add max token based on model id
-def get_max_tokens(model_id):
-    match model_id:
-        case "anthropic.claude-v2:1":
-            return 200000
-        case "anthropic.claude-3-sonnet-20240229-v1:0":
-            return 200000
-        case _:
-            return 4096
+
+def get_max_tokens(model):
+
+    # if model is not provided, we default to Claude v2
+    if not model:
+        return MAX_TOKENS_MAP[BedrockModel.ANTHROPIC_CLAUDE_V2_1]
+    try:
+        return MAX_TOKENS_MAP[model]
+    except KeyError:
+        logger.error('unable to get the max tokens for the specified model')
+        return -1
 
 
 def validate_model_id_in_bedrock(model_id,modality):
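The new .types module does not appear in this excerpt. A minimal sketch of the shape get_max_tokens depends on — the names mirror the import above, but the members and values are illustrative assumptions (the 200,000 figures match what the removed match statement hard-coded):

from enum import Enum

class BedrockModel(str, Enum):
    # Illustrative members; the real enum in .types likely covers more models.
    ANTHROPIC_CLAUDE_V2_1 = "anthropic.claude-v2:1"
    ANTHROPIC_CLAUDE_3_SONNET = "anthropic.claude-3-sonnet-20240229-v1:0"

# Max-token (context window) sizes keyed by model identifier.
MAX_TOKENS_MAP = {
    BedrockModel.ANTHROPIC_CLAUDE_V2_1: 200000,
    BedrockModel.ANTHROPIC_CLAUDE_3_SONNET: 200000,
}

Because BedrockModel mixes in str, a raw model-id string hashes equal to its enum member, so MAX_TOKENS_MAP[model] resolves whether the caller passes the enum or the plain string.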
