-
Notifications
You must be signed in to change notification settings - Fork 96
feat: support bedrock reasoning content response #1074
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
75373c1
adad05f
9e8247f
ade7b91
26e2d5b
5f3930a
dcc8144
5a77091
205689e
92e76c6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,13 +16,51 @@ | |
"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream" | ||
corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" | ||
extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" | ||
"github.com/fatih/structs" | ||
openaigo "github.com/openai/openai-go" | ||
openAIconstant "github.com/openai/openai-go/shared/constant" | ||
"k8s.io/utils/ptr" | ||
|
||
"github.com/envoyproxy/ai-gateway/internal/apischema/awsbedrock" | ||
"github.com/envoyproxy/ai-gateway/internal/apischema/openai" | ||
tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api" | ||
) | ||
|
||
// CustomChatCompletionMessage embeds the original and adds our ExtraFields. | ||
// This is where the custom marshaling logic will live. | ||
type CustomChatCompletionMessage struct { | ||
openaigo.ChatCompletionMessage | ||
ExtraFields map[string]interface{} `json:"-"` | ||
} | ||
|
||
// CustomChatCompletionChoice shadows the original Message field with our custom type. | ||
type CustomChatCompletionChoice struct { | ||
openaigo.ChatCompletionChoice | ||
Message CustomChatCompletionMessage `json:"message"` | ||
} | ||
|
||
// CustomChatCompletion shadows the original Choices slice with our custom type. | ||
type CustomChatCompletion struct { | ||
openaigo.ChatCompletion | ||
Choices []CustomChatCompletionChoice `json:"choices"` | ||
} | ||
|
||
// MarshalJSON implements a custom marshaler for CustomChatCompletionMessage. | ||
// It merges the standard fields with the contents of ExtraFields. | ||
func (c CustomChatCompletionMessage) MarshalJSON() ([]byte, error) { | ||
// 1. Directly convert the embedded struct to a map using the library. | ||
// This respects all the `json` tags on the struct's fields. | ||
tempMap := structs.Map(c.ChatCompletionMessage) | ||
|
||
// 2. Iterate through your ExtraFields and merge them into the map. | ||
for key, value := range c.ExtraFields { | ||
tempMap[key] = value | ||
} | ||
|
||
// 3. Marshal the final, merged map into a JSON byte slice. | ||
return json.Marshal(tempMap) | ||
} | ||
|
||
// NewChatCompletionOpenAIToAWSBedrockTranslator implements [Factory] for OpenAI to AWS Bedrock translation. | ||
func NewChatCompletionOpenAIToAWSBedrockTranslator(modelNameOverride string) OpenAIChatCompletionTranslator { | ||
return &openAIToAWSBedrockTranslatorV1ChatCompletion{modelNameOverride: modelNameOverride} | ||
|
@@ -82,6 +120,14 @@ | |
bedrockReq.InferenceConfig.StopSequences = stopSequence | ||
} | ||
|
||
// Handle Anthropic vendor fields if present. Currently only supports thinking fields. | ||
if openAIReq.AnthropicVendorFields != nil && openAIReq.Thinking != nil { | ||
if bedrockReq.AdditionalModelRequestFields == nil { | ||
bedrockReq.AdditionalModelRequestFields = make(map[string]interface{}) | ||
} | ||
bedrockReq.AdditionalModelRequestFields["thinking"] = openAIReq.Thinking | ||
} | ||
|
||
// Convert Chat Completion messages. | ||
err = o.openAIMessageToBedrockMessage(openAIReq, &bedrockReq) | ||
if err != nil { | ||
|
@@ -467,21 +513,20 @@ | |
|
||
func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) bedrockToolUseToOpenAICalls( | ||
toolUse *awsbedrock.ToolUseBlock, | ||
) *openai.ChatCompletionMessageToolCallParam { | ||
) openaigo.ChatCompletionMessageToolCall { | ||
if toolUse == nil { | ||
return nil | ||
return openaigo.ChatCompletionMessageToolCall{} | ||
} | ||
arguments, err := json.Marshal(toolUse.Input) | ||
if err != nil { | ||
return nil | ||
return openaigo.ChatCompletionMessageToolCall{} | ||
} | ||
return &openai.ChatCompletionMessageToolCallParam{ | ||
ID: &toolUse.ToolUseID, | ||
Function: openai.ChatCompletionMessageToolCallFunctionParam{ | ||
return openaigo.ChatCompletionMessageToolCall{ | ||
ID: toolUse.ToolUseID, | ||
Function: openaigo.ChatCompletionMessageToolCallFunction{ | ||
Name: toolUse.Name, | ||
Arguments: string(arguments), | ||
}, | ||
Type: openai.ChatCompletionMessageToolCallTypeFunction, | ||
} | ||
} | ||
|
||
|
@@ -582,40 +627,50 @@ | |
if err = json.NewDecoder(body).Decode(&bedrockResp); err != nil { | ||
return nil, nil, tokenUsage, fmt.Errorf("failed to unmarshal body: %w", err) | ||
} | ||
openAIResp := &openai.ChatCompletionResponse{ | ||
Object: "chat.completion", | ||
Choices: make([]openai.ChatCompletionResponseChoice, 0), | ||
|
||
openAIResp := CustomChatCompletion{ | ||
Choices: make([]CustomChatCompletionChoice, 0), | ||
} | ||
|
||
// Convert token usage. | ||
if bedrockResp.Usage != nil { | ||
tokenUsage = LLMTokenUsage{ | ||
InputTokens: uint32(bedrockResp.Usage.InputTokens), //nolint:gosec | ||
OutputTokens: uint32(bedrockResp.Usage.OutputTokens), //nolint:gosec | ||
TotalTokens: uint32(bedrockResp.Usage.TotalTokens), //nolint:gosec | ||
} | ||
openAIResp.Usage = openai.ChatCompletionResponseUsage{ | ||
TotalTokens: bedrockResp.Usage.TotalTokens, | ||
PromptTokens: bedrockResp.Usage.InputTokens, | ||
CompletionTokens: bedrockResp.Usage.OutputTokens, | ||
openAIResp.Usage = openaigo.CompletionUsage{ | ||
TotalTokens: int64(bedrockResp.Usage.TotalTokens), | ||
PromptTokens: int64(bedrockResp.Usage.InputTokens), | ||
CompletionTokens: int64(bedrockResp.Usage.OutputTokens), | ||
} | ||
} | ||
|
||
// AWS Bedrock does not support N(multiple choices) > 0, so there could be only one choice. | ||
choice := openai.ChatCompletionResponseChoice{ | ||
Index: (int64)(0), | ||
Message: openai.ChatCompletionResponseChoiceMessage{ | ||
Role: bedrockResp.Output.Message.Role, | ||
choice := CustomChatCompletionChoice{ | ||
Message: CustomChatCompletionMessage{ | ||
ChatCompletionMessage: openaigo.ChatCompletionMessage{ | ||
Role: openAIconstant.Assistant(bedrockResp.Output.Message.Role), | ||
}, | ||
ExtraFields: make(map[string]interface{}), | ||
}, | ||
FinishReason: o.bedrockStopReasonToOpenAIStopReason(bedrockResp.StopReason), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we set the finishReason inplace? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same for Index |
||
} | ||
choice.Index = (int64)(0) | ||
choice.FinishReason = string(o.bedrockStopReasonToOpenAIStopReason(bedrockResp.StopReason)) | ||
|
||
for _, output := range bedrockResp.Output.Message.Content { | ||
if toolCall := o.bedrockToolUseToOpenAICalls(output.ToolUse); toolCall != nil { | ||
choice.Message.ToolCalls = []openai.ChatCompletionMessageToolCallParam{*toolCall} | ||
} else if output.Text != nil { | ||
// For the converse response the assumption is that there is only one text content block, we take the first one. | ||
if choice.Message.Content == nil { | ||
choice.Message.Content = output.Text | ||
switch { | ||
case output.ToolUse != nil: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
toolCall := o.bedrockToolUseToOpenAICalls(output.ToolUse) | ||
choice.Message.ToolCalls = append(choice.Message.ToolCalls, toolCall) | ||
|
||
case output.Text != nil: | ||
// We expect only one text content block in the response. | ||
if choice.Message.Content == "" { | ||
choice.Message.Content = *output.Text | ||
} | ||
case output.ReasoningContent != nil && output.ReasoningContent.ReasoningText != nil: | ||
choice.Message.ExtraFields["reasoning_content"] = *output.ReasoningContent | ||
} | ||
} | ||
openAIResp.Choices = append(openAIResp.Choices, choice) | ||
|
@@ -627,7 +682,7 @@ | |
headerMutation = &extprocv3.HeaderMutation{} | ||
setContentLength(headerMutation, mut.Body) | ||
if span != nil { | ||
span.RecordResponse(openAIResp) | ||
Check failure on line 685 in internal/extproc/translator/openai_awsbedrock.go
|
||
} | ||
return headerMutation, &extprocv3.BodyMutation{Mutation: mut}, tokenUsage, nil | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
switch to a pointer return?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We just dereference the pointer right after returning — what do you think is the best approach here?