mirror of
https://github.com/Wei-Shaw/sub2api.git
synced 2026-03-30 02:27:11 +00:00
415 lines
12 KiB
Go
415 lines
12 KiB
Go
package service
|
||
|
||
import (
|
||
"bufio"
|
||
"context"
|
||
"encoding/base64"
|
||
"errors"
|
||
"fmt"
|
||
"hash/crc32"
|
||
"io"
|
||
"net/http"
|
||
"sync/atomic"
|
||
"time"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/tidwall/gjson"
|
||
"github.com/tidwall/sjson"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||
)
|
||
|
||
// handleBedrockStreamingResponse 处理 Bedrock InvokeModelWithResponseStream 的 EventStream 响应
|
||
// Bedrock 返回 AWS EventStream 二进制格式,每个事件的 payload 中 chunk.bytes 是 base64 编码的
|
||
// Claude SSE 事件 JSON。本方法解码后转换为标准 SSE 格式写入客户端。
|
||
func (s *GatewayService) handleBedrockStreamingResponse(
|
||
ctx context.Context,
|
||
resp *http.Response,
|
||
c *gin.Context,
|
||
account *Account,
|
||
startTime time.Time,
|
||
model string,
|
||
) (*streamingResult, error) {
|
||
w := c.Writer
|
||
flusher, ok := w.(http.Flusher)
|
||
if !ok {
|
||
return nil, errors.New("streaming not supported")
|
||
}
|
||
|
||
c.Header("Content-Type", "text/event-stream")
|
||
c.Header("Cache-Control", "no-cache")
|
||
c.Header("Connection", "keep-alive")
|
||
c.Header("X-Accel-Buffering", "no")
|
||
if v := resp.Header.Get("x-amzn-requestid"); v != "" {
|
||
c.Header("x-request-id", v)
|
||
}
|
||
|
||
usage := &ClaudeUsage{}
|
||
var firstTokenMs *int
|
||
clientDisconnected := false
|
||
|
||
// Bedrock EventStream 使用 application/vnd.amazon.eventstream 二进制格式。
|
||
// 每个帧结构:total_length(4) + headers_length(4) + prelude_crc(4) + headers + payload + message_crc(4)
|
||
// 但更实用的方式是使用行扫描找 JSON chunks,因为 Bedrock 的响应在二进制帧中。
|
||
// 我们使用 EventStream decoder 来正确解析。
|
||
decoder := newBedrockEventStreamDecoder(resp.Body)
|
||
|
||
type decodeEvent struct {
|
||
payload []byte
|
||
err error
|
||
}
|
||
events := make(chan decodeEvent, 16)
|
||
done := make(chan struct{})
|
||
sendEvent := func(ev decodeEvent) bool {
|
||
select {
|
||
case events <- ev:
|
||
return true
|
||
case <-done:
|
||
return false
|
||
}
|
||
}
|
||
var lastReadAt atomic.Int64
|
||
lastReadAt.Store(time.Now().UnixNano())
|
||
|
||
go func() {
|
||
defer close(events)
|
||
for {
|
||
payload, err := decoder.Decode()
|
||
if err != nil {
|
||
if err == io.EOF {
|
||
return
|
||
}
|
||
_ = sendEvent(decodeEvent{err: err})
|
||
return
|
||
}
|
||
lastReadAt.Store(time.Now().UnixNano())
|
||
if !sendEvent(decodeEvent{payload: payload}) {
|
||
return
|
||
}
|
||
}
|
||
}()
|
||
defer close(done)
|
||
|
||
streamInterval := time.Duration(0)
|
||
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||
}
|
||
var intervalTicker *time.Ticker
|
||
if streamInterval > 0 {
|
||
intervalTicker = time.NewTicker(streamInterval)
|
||
defer intervalTicker.Stop()
|
||
}
|
||
var intervalCh <-chan time.Time
|
||
if intervalTicker != nil {
|
||
intervalCh = intervalTicker.C
|
||
}
|
||
|
||
for {
|
||
select {
|
||
case ev, ok := <-events:
|
||
if !ok {
|
||
if !clientDisconnected {
|
||
flusher.Flush()
|
||
}
|
||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
||
}
|
||
if ev.err != nil {
|
||
if clientDisconnected {
|
||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||
}
|
||
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
|
||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||
}
|
||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("bedrock stream read error: %w", ev.err)
|
||
}
|
||
|
||
// payload 是 JSON,提取 chunk.bytes(base64 编码的 Claude SSE 事件数据)
|
||
sseData := extractBedrockChunkData(ev.payload)
|
||
if sseData == nil {
|
||
continue
|
||
}
|
||
|
||
if firstTokenMs == nil {
|
||
ms := int(time.Since(startTime).Milliseconds())
|
||
firstTokenMs = &ms
|
||
}
|
||
|
||
// 转换 Bedrock 特有的 amazon-bedrock-invocationMetrics 为标准 Anthropic usage 格式
|
||
// 同时移除该字段避免透传给客户端
|
||
sseData = transformBedrockInvocationMetrics(sseData)
|
||
|
||
// 解析 SSE 事件数据提取 usage
|
||
s.parseSSEUsagePassthrough(string(sseData), usage)
|
||
|
||
// 确定 SSE event type
|
||
eventType := gjson.GetBytes(sseData, "type").String()
|
||
|
||
// 写入标准 SSE 格式
|
||
if !clientDisconnected {
|
||
var writeErr error
|
||
if eventType != "" {
|
||
_, writeErr = fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, sseData)
|
||
} else {
|
||
_, writeErr = fmt.Fprintf(w, "data: %s\n\n", sseData)
|
||
}
|
||
if writeErr != nil {
|
||
clientDisconnected = true
|
||
logger.LegacyPrintf("service.gateway", "[Bedrock] Client disconnected during streaming, continue draining for usage: account=%d", account.ID)
|
||
} else {
|
||
flusher.Flush()
|
||
}
|
||
}
|
||
|
||
case <-intervalCh:
|
||
lastRead := time.Unix(0, lastReadAt.Load())
|
||
if time.Since(lastRead) < streamInterval {
|
||
continue
|
||
}
|
||
if clientDisconnected {
|
||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||
}
|
||
logger.LegacyPrintf("service.gateway", "[Bedrock] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval)
|
||
if s.rateLimitService != nil {
|
||
s.rateLimitService.HandleStreamTimeout(ctx, account, model)
|
||
}
|
||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||
}
|
||
}
|
||
}
|
||
|
||
// extractBedrockChunkData 从 Bedrock EventStream payload 中提取 Claude SSE 事件数据
|
||
// Bedrock payload 格式:{"bytes":"<base64-encoded-json>"}
|
||
func extractBedrockChunkData(payload []byte) []byte {
|
||
b64 := gjson.GetBytes(payload, "bytes").String()
|
||
if b64 == "" {
|
||
return nil
|
||
}
|
||
decoded, err := base64.StdEncoding.DecodeString(b64)
|
||
if err != nil {
|
||
return nil
|
||
}
|
||
return decoded
|
||
}
|
||
|
||
// transformBedrockInvocationMetrics 将 Bedrock 特有的 amazon-bedrock-invocationMetrics
|
||
// 转换为标准 Anthropic usage 格式,并从 SSE 数据中移除该字段。
|
||
//
|
||
// Bedrock Invoke 返回的 message_delta 事件可能包含:
|
||
//
|
||
// {"type":"message_delta","delta":{...},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}
|
||
//
|
||
// 转换为:
|
||
//
|
||
// {"type":"message_delta","delta":{...},"usage":{"input_tokens":150,"output_tokens":42}}
|
||
func transformBedrockInvocationMetrics(data []byte) []byte {
|
||
metrics := gjson.GetBytes(data, "amazon-bedrock-invocationMetrics")
|
||
if !metrics.Exists() || !metrics.IsObject() {
|
||
return data
|
||
}
|
||
|
||
// 移除 Bedrock 特有字段
|
||
data, _ = sjson.DeleteBytes(data, "amazon-bedrock-invocationMetrics")
|
||
|
||
// 如果已有标准 usage 字段,不覆盖
|
||
if gjson.GetBytes(data, "usage").Exists() {
|
||
return data
|
||
}
|
||
|
||
// 转换 camelCase → snake_case 写入 usage
|
||
inputTokens := metrics.Get("inputTokenCount")
|
||
outputTokens := metrics.Get("outputTokenCount")
|
||
if inputTokens.Exists() {
|
||
data, _ = sjson.SetBytes(data, "usage.input_tokens", inputTokens.Int())
|
||
}
|
||
if outputTokens.Exists() {
|
||
data, _ = sjson.SetBytes(data, "usage.output_tokens", outputTokens.Int())
|
||
}
|
||
|
||
return data
|
||
}
|
||
|
||
// bedrockEventStreamDecoder decodes AWS EventStream binary frames.
//
// Each frame is laid out as:
//
//	[total_byte_length:   4 bytes]
//	[headers_byte_length: 4 bytes]
//	[prelude_crc:         4 bytes]
//	[headers:             variable]
//	[payload:             variable]
//	[message_crc:         4 bytes]
type bedrockEventStreamDecoder struct {
	// reader is the buffered source the frames are read from.
	reader *bufio.Reader
}
|
||
|
||
func newBedrockEventStreamDecoder(r io.Reader) *bedrockEventStreamDecoder {
|
||
return &bedrockEventStreamDecoder{
|
||
reader: bufio.NewReaderSize(r, 64*1024),
|
||
}
|
||
}
|
||
|
||
// Decode 读取下一个 EventStream 帧并返回 chunk 类型事件的 payload
|
||
func (d *bedrockEventStreamDecoder) Decode() ([]byte, error) {
|
||
for {
|
||
// 读取 prelude: total_length(4) + headers_length(4) + prelude_crc(4) = 12 bytes
|
||
prelude := make([]byte, 12)
|
||
if _, err := io.ReadFull(d.reader, prelude); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 验证 prelude CRC(AWS EventStream 使用标准 CRC32 / IEEE)
|
||
preludeCRC := bedrockReadUint32(prelude[8:12])
|
||
if crc32.Checksum(prelude[0:8], crc32IEEETable) != preludeCRC {
|
||
return nil, fmt.Errorf("eventstream prelude CRC mismatch")
|
||
}
|
||
|
||
totalLength := bedrockReadUint32(prelude[0:4])
|
||
headersLength := bedrockReadUint32(prelude[4:8])
|
||
|
||
if totalLength < 16 { // minimum: 12 prelude + 4 message_crc
|
||
return nil, fmt.Errorf("invalid eventstream frame: total_length=%d", totalLength)
|
||
}
|
||
|
||
// 读取 headers + payload + message_crc
|
||
remaining := int(totalLength) - 12
|
||
if remaining <= 0 {
|
||
continue
|
||
}
|
||
data := make([]byte, remaining)
|
||
if _, err := io.ReadFull(d.reader, data); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 验证 message CRC(覆盖 prelude + headers + payload)
|
||
messageCRC := bedrockReadUint32(data[len(data)-4:])
|
||
h := crc32.New(crc32IEEETable)
|
||
h.Write(prelude)
|
||
h.Write(data[:len(data)-4])
|
||
if h.Sum32() != messageCRC {
|
||
return nil, fmt.Errorf("eventstream message CRC mismatch")
|
||
}
|
||
|
||
// 解析 headers
|
||
headers := data[:headersLength]
|
||
payload := data[headersLength : len(data)-4] // 去掉 message_crc
|
||
|
||
// 从 headers 中提取 :event-type
|
||
eventType := extractEventStreamHeaderValue(headers, ":event-type")
|
||
|
||
// 只处理 chunk 事件
|
||
if eventType == "chunk" {
|
||
// payload 是完整的 JSON,包含 bytes 字段
|
||
return payload, nil
|
||
}
|
||
|
||
// 检查异常事件
|
||
exceptionType := extractEventStreamHeaderValue(headers, ":exception-type")
|
||
if exceptionType != "" {
|
||
return nil, fmt.Errorf("bedrock exception: %s: %s", exceptionType, string(payload))
|
||
}
|
||
|
||
messageType := extractEventStreamHeaderValue(headers, ":message-type")
|
||
if messageType == "exception" || messageType == "error" {
|
||
return nil, fmt.Errorf("bedrock error: %s", string(payload))
|
||
}
|
||
|
||
// 跳过其他事件类型(如 initial-response)
|
||
}
|
||
}
|
||
|
||
// extractEventStreamHeaderValue 从 EventStream headers 二进制数据中提取指定 header 的字符串值
|
||
// EventStream header 格式:
|
||
//
|
||
// [name_length: 1 byte][name: variable][value_type: 1 byte][value: variable]
|
||
//
|
||
// value_type = 7 表示 string 类型,前 2 bytes 为长度
|
||
func extractEventStreamHeaderValue(headers []byte, targetName string) string {
|
||
pos := 0
|
||
for pos < len(headers) {
|
||
if pos >= len(headers) {
|
||
break
|
||
}
|
||
nameLen := int(headers[pos])
|
||
pos++
|
||
if pos+nameLen > len(headers) {
|
||
break
|
||
}
|
||
name := string(headers[pos : pos+nameLen])
|
||
pos += nameLen
|
||
|
||
if pos >= len(headers) {
|
||
break
|
||
}
|
||
valueType := headers[pos]
|
||
pos++
|
||
|
||
switch valueType {
|
||
case 7: // string
|
||
if pos+2 > len(headers) {
|
||
return ""
|
||
}
|
||
valueLen := int(bedrockReadUint16(headers[pos : pos+2]))
|
||
pos += 2
|
||
if pos+valueLen > len(headers) {
|
||
return ""
|
||
}
|
||
value := string(headers[pos : pos+valueLen])
|
||
pos += valueLen
|
||
if name == targetName {
|
||
return value
|
||
}
|
||
case 0: // bool true
|
||
if name == targetName {
|
||
return "true"
|
||
}
|
||
case 1: // bool false
|
||
if name == targetName {
|
||
return "false"
|
||
}
|
||
case 2: // byte
|
||
pos++
|
||
if name == targetName {
|
||
return ""
|
||
}
|
||
case 3: // short
|
||
pos += 2
|
||
if name == targetName {
|
||
return ""
|
||
}
|
||
case 4: // int
|
||
pos += 4
|
||
if name == targetName {
|
||
return ""
|
||
}
|
||
case 5: // long
|
||
pos += 8
|
||
if name == targetName {
|
||
return ""
|
||
}
|
||
case 6: // bytes
|
||
if pos+2 > len(headers) {
|
||
return ""
|
||
}
|
||
valueLen := int(bedrockReadUint16(headers[pos : pos+2]))
|
||
pos += 2 + valueLen
|
||
case 8: // timestamp
|
||
pos += 8
|
||
case 9: // uuid
|
||
pos += 16
|
||
default:
|
||
return "" // 未知类型,无法继续解析
|
||
}
|
||
}
|
||
return ""
|
||
}
|
||
|
||
// crc32IEEETable is the CRC-32 table for the IEEE polynomial, which is the
// checksum the AWS EventStream framing uses for its prelude and message CRCs.
var crc32IEEETable = crc32.IEEETable
|
||
|
||
// bedrockReadUint32 interprets the first four bytes of b as a big-endian
// unsigned 32-bit integer. b must hold at least four bytes.
func bedrockReadUint32(b []byte) uint32 {
	upper := uint32(b[0])<<8 | uint32(b[1])
	lower := uint32(b[2])<<8 | uint32(b[3])
	return upper<<16 | lower
}
|
||
|
||
// bedrockReadUint16 interprets the first two bytes of b as a big-endian
// unsigned 16-bit integer. b must hold at least two bytes.
func bedrockReadUint16(b []byte) uint16 {
	high, low := b[0], b[1]
	return uint16(high)<<8 | uint16(low)
}
|