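Automatically Creating Social Media Images with Generative AI

Published: May 29, 2024
EventBridge
Step Functions
Amazon Bedrock
Go
Comprehend
Tutorial
AWS
Olawale Olaleye
AWS experience
100 - Beginner
Prerequisites

An AWS account for the global (non-China) Regions

Time to complete
10 minutes
Last updated
May 29, 2024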

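Frontend

First, the application needs a frontend that presents it visually and offers a user-friendly interface. The frontend is built with:

  • TypeScript: ensures type safety and higher code quality.
  • React Query and Axios: handle the API requests.
  • Ant Design (Antd): a component library. It was a good opportunity to try something new, and Antd looked quite good.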

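AWS integration

The AWS part is the focus of this project. To keep the application robust and maintainable, it uses the following setup:

  • AWS SAM: chosen as the tool for writing infrastructure as code.
  • Amazon API Gateway: the entry point to the backend, which manages and routes incoming requests. Integrating it with Amazon EventBridge decouples the application components.
  • Amazon EventBridge: receives events from Amazon API Gateway and forwards them to the appropriate part of the application (AWS Step Functions).
  • AWS Step Functions: receives events from Amazon EventBridge and orchestrates the workflow that handles the frontend's requests, making every step easy to follow.
  • AWS Lambda (in Go): individual steps in the AWS Step Functions workflow that handle specific tasks such as preparing the prompt and generating the image.
  • Amazon Comprehend: while the prompt is being prepared, analyzes the sentiment of the post and extracts its key phrases as context, which improves the quality of the generated image.
  • Amazon Bedrock and DALL-E: called from AWS Lambda to create the image based on the analyzed social media post.

DALL-E was chosen for creating images because it is a mainstream choice for image generation. During development, however, Amazon Bedrock support was added as well, so users can choose between DALL-E and Amazon Bedrock when generating an image.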

A diagram of the final design:

(Architecture diagram)

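Prerequisites and setup

  • Frontend: Node.js and TypeScript
    • Component library: Ant Design (Antd), installed with npm i antd
    • Other libraries: TanStack React Query and Axios, installed with npm i @tanstack/react-query axios
  • AWS: AWS SAM CLI (see its installation instructions)
    • Language: Go (see the Go documentation)
  • DALL-E: an OpenAI developer account is required to create and obtain an API key.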

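How to build – disclaimer

This tutorial is intended only as a proof of concept, not for production use, and security was not considered while building it. For example, the API has no authentication, and the S3 bucket lets anyone upload and read objects. Use it for testing only and do not deploy it to any production environment. If you do need to use it in production, add safeguards such as authentication.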

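How to build: backend

First, initialize a new AWS SAM project in a new folder with the sam init command. Select the "Hello World" example template, then select Go (provided.al2023) as the runtime.

This creates the project, including a template.yaml file. The next step is to add the necessary infrastructure to that template.

To do so, delete everything in the template file and enter the following skeleton. Every resource defined in the following sections goes under Resources: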

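AWSTemplateFormatVersion: "2010-09-09"
Transform: AWS::Serverless-2016-10-31

Parameters:
  OpenAIApiKey: # referenced by DalleImageGenerationFunction
    Type: String
    NoEcho: true

Resources:

Outputs:
  ApiUrl: # base URL of the HTTP API, used as REACT_APP_API_URL in the frontend
    Value: !Sub "https://${PostToImageApi}.execute-api.${AWS::Region}.amazonaws.com"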

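How to build: backend – Amazon EventBridge

Define an Amazon EventBridge event bus to receive the events coming from Amazon API Gateway. Then define a rule that routes events with source api-gateway and detail-type PreparePrompt to the AWS Step Functions state machine defined later. The EventBridgeExecutionRole grants Amazon EventBridge permission to start executions of that state machine.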
PostToImageEventBus:
  Type: AWS::Events::EventBus
  Properties:
    Name: PostToImageEventBus

StateMachineEventBridgeRule:
  Type: AWS::Events::Rule
  Properties:
    EventBusName: !Ref PostToImageEventBus
    EventPattern:
      source:
        - "api-gateway"
      detail-type:
        - "PreparePrompt"
    Targets:
      - Arn: !GetAtt PostToImageStateMachine.Arn
        Id: "PostToImageStateMachineTarget"
        RoleArn: !GetAtt EventBridgeExecutionRole.Arn
    State: "ENABLED"

EventBridgeExecutionRole:
  Type: "AWS::IAM::Role"
  Properties:
    AssumeRolePolicyDocument:
      Version: "2012-10-17"
      Statement:
        - Effect: Allow
          Principal:
            Service: "events.amazonaws.com"
          Action:
            - "sts:AssumeRole"
    Policies:
      - PolicyName: "EventBridgeStepFunctionsExecutionPolicy"
        PolicyDocument:
          Version: "2012-10-17"
          Statement:
            - Effect: Allow
              Action:
                - "states:StartExecution"
              Resource: !GetAtt PostToImageStateMachine.Arn
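The following minimal Go sketch makes the routing concrete. It shows which events the rule above lets through and what the state machine receives. It assumes the standard EventBridge envelope, in which the request body ends up under detail and the whole envelope is the state machine input. It reuses the detail shape of the prompt preparation function's InputEvent shown later. The envelope type, the matchesRule helper and the sample event are hypothetical. They only mimic the exact-value matching of this particular EventPattern, which EventBridge performs itself.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// envelope holds the EventBridge fields this tutorial relies on.
// The real event also carries id, time, account, region and more.
type envelope struct {
	Source     string `json:"source"`
	DetailType string `json:"detail-type"`
	Detail     struct {
		Text    string `json:"text"`
		S3Key   string `json:"s3Key"`
		Bedrock bool   `json:"bedrock"`
	} `json:"detail"`
}

// matchesRule mimics StateMachineEventBridgeRule: each field must equal
// the single value listed in the EventPattern (hypothetical helper).
func matchesRule(e envelope) bool {
	return e.Source == "api-gateway" && e.DetailType == "PreparePrompt"
}

// sampleEvent is a hypothetical event as API Gateway might publish it,
// with the POST body placed under "detail".
const sampleEvent = `{
	"source": "api-gateway",
	"detail-type": "PreparePrompt",
	"detail": {"text": "Sunny day at the beach!", "s3Key": "image/testing", "bedrock": true}
}`

// parseEvent decodes the raw envelope the way the Lambda runtime would.
func parseEvent(raw string) (envelope, error) {
	var e envelope
	err := json.Unmarshal([]byte(raw), &e)
	return e, err
}

func main() {
	e, err := parseEvent(sampleEvent)
	if err != nil {
		panic(err)
	}
	fmt.Println(matchesRule(e), e.Detail.S3Key, e.Detail.Bedrock)
}
```

This nesting is why the prompt preparation function reads event.Detail rather than top-level fields.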

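How to build: backend – API

The frontend needs an API as its entry point for communicating with AWS. While researching how to integrate Amazon EventBridge with Amazon API Gateway, the API was defined based on an existing example and stored in a newly created api.yaml file. The API has two routes:
  • POST /generate-image, which starts the image creation process.
  • GET /get-image, which the frontend polls to retrieve the image once it has been created.

In template.yaml, add the following code to define the API and the roles it needs.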

PostToImageApi:
  Type: AWS::Serverless::HttpApi
  Properties:
    DefinitionBody:
      "Fn::Transform":
        Name: "AWS::Include"
        Parameters:
          Location: "./api.yaml"

# Permission to put events to eventBridge
HttpApiEvenbridgeRole:
  Type: "AWS::IAM::Role"
  Properties:
    AssumeRolePolicyDocument:
      Version: "2012-10-17"
      Statement:
        - Effect: "Allow"
          Principal:
            Service: "apigateway.amazonaws.com"
          Action:
            - "sts:AssumeRole"
    Policies:
      - PolicyName: ApiDirectWriteEventBridge
        PolicyDocument:
          Version: "2012-10-17"
          Statement:
            - Effect: Allow
              Action:
                - events:PutEvents
              Resource:
                - !GetAtt PostToImageEventBus.Arn

# Permission to invoke the GetImage Lambda
HttpApiLambdaRole:
  Type: "AWS::IAM::Role"
  Properties:
    AssumeRolePolicyDocument:
      Version: "2012-10-17"
      Statement:
        - Effect: "Allow"
          Principal:
            Service: "apigateway.amazonaws.com"
          Action:
            - "sts:AssumeRole"
    Policies:
      - PolicyName: "LambdaExecutionPolicy"
        PolicyDocument:
          Version: "2012-10-17"
          Statement:
            - Effect: Allow
              Action:
                - "lambda:InvokeFunction"
              Resource: !GetAtt GetImageFunction.Arn

The API definition in api.yaml is as follows:

openapi: "3.0.1"
info:
  title: "HTTP API"
paths:
  /get-image:
    get:
      parameters:
        - name: s3Key
          in: query
          description: "Path to image in S3 bucket"
          required: true
          schema:
            type: string
      responses:
        200:
          description: "Successful response"
          content:
            application/json:
              schema:
                type: object
                properties:
                  url:
                    type: string
                    description: "URL of the retrieved image"
      x-amazon-apigateway-integration:
        type: "aws_proxy"
        httpMethod: "POST"
        uri:
          Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${GetImageFunction.Arn}/invocations
        credentials:
          Fn::GetAtt: [HttpApiLambdaRole, Arn]
        payloadFormatVersion: "1.0"
        passthroughBehavior: "when_no_match"
  /generate-image:
    post:
      responses:
        default:
          description: "API to EventBridge"
      x-amazon-apigateway-integration:
        integrationSubtype: "EventBridge-PutEvents"
        credentials:
          Fn::GetAtt: [HttpApiEvenbridgeRole, Arn]
        requestParameters:
          Detail: "$request.body"
          DetailType: PreparePrompt
          Source: api-gateway
          EventBusName:
            Fn::GetAtt: [PostToImageEventBus, Name]
        payloadFormatVersion: "1.0"
        type: "aws_proxy"
        connectionType: "INTERNET"
x-amazon-apigateway-importexport-version: "1.0"
x-amazon-apigateway-cors:
  allowOrigins:
    - "*"
  allowHeaders:
    - "*"
  allowMethods:
    - "PUT"
    - "POST"
    - "DELETE"
    - "HEAD"
    - "GET"
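The following Go sketch illustrates the contract that these two routes define. It builds, but does not send, the two requests the frontend makes. The first is a POST to /generate-image whose JSON body becomes the EventBridge event detail. The second is a GET to /get-image with the s3Key query parameter. The base URL and the helper names are hypothetical placeholders for the deployed endpoint.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
)

// generateRequest matches the body the frontend posts; API Gateway forwards
// it unchanged as the EventBridge event detail ("$request.body").
type generateRequest struct {
	Text    string `json:"text"`
	S3Key   string `json:"s3Key"`
	Bedrock bool   `json:"bedrock"`
}

// newGenerateRequest builds POST {base}/generate-image with a JSON body.
func newGenerateRequest(base string, body generateRequest) (*http.Request, error) {
	payload, err := json.Marshal(body)
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequest(http.MethodPost, base+"/generate-image", bytes.NewReader(payload))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	return req, nil
}

// newGetImageRequest builds GET {base}/get-image?s3Key=...; url.Values
// escapes the "/" in keys such as "image/testing".
func newGetImageRequest(base, s3Key string) (*http.Request, error) {
	q := url.Values{}
	q.Set("s3Key", s3Key)
	return http.NewRequest(http.MethodGet, base+"/get-image?"+q.Encode(), nil)
}

func main() {
	base := "https://example.execute-api.eu-north-1.amazonaws.com" // hypothetical endpoint
	post, _ := newGenerateRequest(base, generateRequest{Text: "Hello", S3Key: "image/testing"})
	get, _ := newGetImageRequest(base, "image/testing")
	fmt.Println(post.Method, post.URL)
	fmt.Println(get.Method, get.URL)
}
```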

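How to build: backend – Amazon S3

Once an image has been created, it needs to be stored. For that, an Amazon S3 bucket is defined, together with the following bucket policy: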

ImageUploadBucket:
  Type: AWS::S3::Bucket
  Properties:
    BucketName: !Sub "${AWS::StackName}-image-upload-bucket"
    PublicAccessBlockConfiguration:
      BlockPublicAcls: false
      BlockPublicPolicy: false
      IgnorePublicAcls: false
      RestrictPublicBuckets: false
    CorsConfiguration:
      CorsRules:
        - AllowedHeaders:
            - "*"
          AllowedMethods:
            - PUT
            - POST
            - DELETE
            - HEAD
            - GET
          AllowedOrigins:
            - "*"
          ExposedHeaders: []

# This makes the bucket available for everyone to do PUT and GET
ImageUploadBucketPolicy:
  Type: AWS::S3::BucketPolicy
  Properties:
    Bucket: !Ref ImageUploadBucket
    PolicyDocument:
      Version: "2012-10-17"
      Statement:
        - Effect: Allow
          Principal: "*"
          Action:
            - s3:PutObject
            - s3:GetObject
          Resource: !Sub "arn:aws:s3:::${AWS::StackName}-image-upload-bucket/*"

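How to build: backend – AWS Step Functions

An AWS Step Functions state machine orchestrates the Lambda functions so that they run one after another in order. It has two alternative paths: one creates the image with DALL-E, and the other creates it with Amazon Bedrock. The CheckBedRock Choice state picks the path from the bedrock field in the output of PromptPreparation. The DALL-E path needs a separate ImageUpload step, whereas the Bedrock function uploads the image itself. The state machine and its role are defined as follows: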

PostToImageStateMachine:
  Type: "AWS::Serverless::StateMachine"
  Properties:
    Definition:
      StartAt: PromptPreparation
      States:
        PromptPreparation:
          Type: Task
          Resource: !GetAtt PromptPreparationAPIFunction.Arn
          Next: CheckBedRock
        # Choice to see which route to take
        CheckBedRock:
          Type: Choice
          Choices:
            - Variable: "$.bedrock"
              BooleanEquals: true
              Next: BedrockImageGeneration
          Default: DalleImageGeneration
        BedrockImageGeneration:
          Type: Task
          Resource: !GetAtt BedrockImageGenerationFunction.Arn
          End: true
        DalleImageGeneration:
          Type: Task
          Resource: !GetAtt DalleImageGenerationFunction.Arn
          Next: ImageUpload
        ImageUpload:
          Type: Task
          Resource: !GetAtt ImageUploadFunction.Arn
          End: true
    Role: !GetAtt PostToImageStateMachineRole.Arn

# Gives permission to the step function to invoke the required lambda functions
# And permission to handle the actual step function
PostToImageStateMachineRole:
  Type: "AWS::IAM::Role"
  Properties:
    AssumeRolePolicyDocument:
      Version: "2012-10-17"
      Statement:
        - Effect: Allow
          Principal:
            Service:
              - states.amazonaws.com
          Action:
            - "sts:AssumeRole"
    Policies:
      - PolicyName: "StateMachineExecutionPolicy"
        PolicyDocument:
          Version: "2012-10-17"
          Statement:
            - Effect: Allow
              Action:
                - "lambda:InvokeFunction"
              Resource:
                - !GetAtt PromptPreparationAPIFunction.Arn
                - !GetAtt DalleImageGenerationFunction.Arn
                - !GetAtt ImageUploadFunction.Arn
                - !GetAtt BedrockImageGenerationFunction.Arn
            - Effect: Allow
              Action:
                - "states:StartExecution"
                - "states:DescribeExecution"
                - "states:StopExecution"
              Resource: "*"
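The following Go sketch simulates, outside AWS, which states an execution visits for a given PromptPreparation output. It is meant to make the Choice logic above easy to check. The route function and the promptOutput type are hypothetical. The sketch treats a missing bedrock field as an error, on the assumption that a Choice rule on a missing path fails the execution. Prompt preparation always sets the field, so this case does not arise in the tutorial.

```go
package main

import (
	"errors"
	"fmt"
)

// promptOutput is the part of PromptPreparation's output that CheckBedRock reads.
// A pointer distinguishes an explicit false from a missing field.
type promptOutput struct {
	Bedrock *bool
}

// route lists the states an execution would pass through:
// bedrock == true goes to BedrockImageGeneration (End: true),
// anything else takes the Default branch to DALL-E and ImageUpload.
func route(out promptOutput) ([]string, error) {
	path := []string{"PromptPreparation", "CheckBedRock"}
	if out.Bedrock == nil {
		// Assumption: a Choice rule on a missing path fails the execution.
		return path, errors.New("$.bedrock is missing from the state input")
	}
	if *out.Bedrock {
		return append(path, "BedrockImageGeneration"), nil
	}
	return append(path, "DalleImageGeneration", "ImageUpload"), nil
}

func main() {
	yes, no := true, false
	p1, _ := route(promptOutput{Bedrock: &yes})
	p2, _ := route(promptOutput{Bedrock: &no})
	fmt.Println(p1)
	fmt.Println(p2)
}
```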

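How to build: backend – defining the Lambda functions in the AWS SAM template

With the state machine in place, the next step is to define the Lambda functions it invokes, plus the function behind the GET /get-image route. Each Go function on provided.al2023 also needs Metadata with BuildMethod: go1.x so that sam build compiles it, as in the Hello World template generated by sam init. Add it to each function below.

The prompt preparation function detects sentiment and key phrases in the text, so it also needs permissions for Amazon Comprehend.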
PromptPreparationAPIFunction:
  Type: AWS::Serverless::Function
  Properties:
    Handler: bootstrap
    Runtime: provided.al2023
    CodeUri: prepare-prompt/
    Timeout: 30
    Policies:
      - AWSLambdaBasicExecutionRole
      - Version: "2012-10-17"
        Statement:
          - Effect: "Allow"
            Action:
              - comprehend:DetectSentiment
              - comprehend:DetectKeyPhrases
            Resource: "*"

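The DALL-E image generation function sends its requests to OpenAI rather than to an AWS service, so it needs no special IAM permissions. It does, however, need the OpenAI API key to authorize its requests to DALL-E, so the key is passed to the function as an environment variable.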

DalleImageGenerationFunction:
  Type: AWS::Serverless::Function
  Properties:
    Handler: bootstrap
    Runtime: provided.al2023
    CodeUri: generate-image-dalle/
    Timeout: 30
    Environment:
      Variables:
        OPENAI_API_KEY: !Ref OpenAIApiKey
    Policies:
      - AWSLambdaBasicExecutionRole

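The Bedrock image generation function calls Amazon Bedrock, so it needs permission to invoke Bedrock models. It also uploads the generated image to S3 itself, so it receives the bucket name as an environment variable. Access to the Amazon Titan Image Generator model may also need to be enabled in the Amazon Bedrock console for the Region the function calls (us-east-1).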

BedrockImageGenerationFunction:
  Type: AWS::Serverless::Function
  Properties:
    Handler: bootstrap
    Runtime: provided.al2023
    CodeUri: generate-image-bedrock/
    Timeout: 30
    Environment:
      Variables:
        BUCKET_NAME: !Ref ImageUploadBucket
    Policies:
      - AWSLambdaBasicExecutionRole
      - Version: "2012-10-17"
        Statement:
          - Effect: Allow
            Action:
              - "bedrock:InvokeModel"
            Resource: "*"

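The Amazon S3 bucket policy allows public access, so the image upload function needs no special permissions. It only needs the bucket name and Region as environment variables.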

ImageUploadFunction:
  Type: AWS::Serverless::Function
  Properties:
    Handler: bootstrap
    Runtime: provided.al2023
    CodeUri: upload-image/
    Timeout: 30
    Environment:
      Variables:
        BUCKET_NAME: !Ref ImageUploadBucket
        REGION: !Ref AWS::Region
    Policies:
      - AWSLambdaBasicExecutionRole

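The last function serves the image through Amazon API Gateway: it checks whether the image exists and returns its URL. The Amazon S3 bucket allows public access, so this function needs no special permissions either. It still needs the bucket name as an environment variable so that it knows which bucket to look in.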

GetImageFunction:
  Type: AWS::Serverless::Function
  Properties:
    Handler: bootstrap
    Runtime: provided.al2023
    CodeUri: get-image/
    Timeout: 30
    Environment:
      Variables:
        BUCKET_NAME: !Ref ImageUploadBucket
    Policies:
      - AWSLambdaBasicExecutionRole

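How to build: backend – implementing the Lambda functions in Go

With all the infrastructure defined, the next step is to implement the logic in the Lambda functions. For the complete folder structure of all the Lambda functions, see the GitHub repository. It also contains examples of the files needed to build the project with sam build.

All functions are commented to help explain the implementation in detail.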

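Prompt preparation function

This Lambda function uses Amazon Comprehend to analyze the sentiment of the post (text) sent from the frontend. It is invoked with the post and an Amazon S3 key (the object key under which the image will be stored). It does the following:

  • Initializes an Amazon Comprehend client to analyze the input text.
  • Detects the sentiment of the text with Amazon Comprehend.
  • Determines the dominant sentiment and converts it into a mood word.
  • Extracts the key phrases from the text and joins them into a string.
  • Builds the prompt and passes it on to the next function.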
package main

import (
	"context"
	"fmt"
	"strings"

	"github.com/aws/aws-lambda-go/lambda"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/comprehend"
)

type InputEvent struct {
	Detail struct {
		Text  string `json:"text"`
		S3Key string `json:"s3Key"`
		Bedrock bool `json:"bedrock"`
	} `json:"detail"`
}

type OutputEvent struct {
	Prompt string `json:"prompt"`
	S3Key  string `json:"s3Key"`
	Bedrock bool `json:"bedrock"`
}

// function to get the top sentiment and convert it for the prompt
func getTopSentiment(sentimentScores *comprehend.SentimentScore) string {
	sentiments := map[string]float64{
		"Positive": *sentimentScores.Positive,
		"Negative": *sentimentScores.Negative,
		"Neutral":  *sentimentScores.Neutral,
		"Mixed":    *sentimentScores.Mixed,
	}

	var topSentiment string
	var maxScore float64

	for sentiment, score := range sentiments {
		if score > maxScore {
			maxScore = score
			topSentiment = sentiment
		}
	}

	switch topSentiment {
	case "Positive":
		topSentiment = "happy"
	case "Negative":
		topSentiment = "sad"
	case "Neutral":
		fallthrough
	case "Mixed":
		topSentiment = "neutral"
	}

	return topSentiment
}

func handler(ctx context.Context, event InputEvent) (OutputEvent, error) {
	text := event.Detail.Text
	// Creating a new session for Comprehend and forcing eu-west-1 as the region,
	// since Comprehend is not available in eu-north-1
	svc := comprehend.New(session.Must(session.NewSession(&aws.Config{
		Region: aws.String("eu-west-1"),
	})))

	// Preparing the parameters to be sent to Comprehend
	sentimentParams := &comprehend.DetectSentimentInput{
		Text:         aws.String(text),
		LanguageCode: aws.String("en"),
	}

	keyPhraseParams := &comprehend.DetectKeyPhrasesInput{
		Text:         aws.String(text),
		LanguageCode: aws.String("en"),
	}

	// Detecting the sentiment of the provided text
	sentimentResult, err := svc.DetectSentiment(sentimentParams)
	if err != nil {
		return OutputEvent{}, fmt.Errorf("failed to detect sentiment: %w", err)
	}

	// Extracting the key phrases of the provided text
	keyPhraseResult, err := svc.DetectKeyPhrases(keyPhraseParams)
	if err != nil {
		return OutputEvent{}, fmt.Errorf("failed to detect key phrases: %w", err)
	}

	// Getting the top sentiment and starting the word list with that mood
	topSentiment := getTopSentiment(sentimentResult.SentimentScore)
	keyPhrases := []string{topSentiment}
	// Appending all key phrases to the word list
	for _, phrase := range keyPhraseResult.KeyPhrases {
		keyPhrases = append(keyPhrases, *phrase.Text)
	}
	// Creating a single string from the key words
	summary := strings.Join(keyPhrases, " ")

	// Creating the prompt to be used for image generation
	prompt := fmt.Sprintf("Generate an image based on the following key words: %s", summary)

	// Creating the output and returning it to the next function
	outputEvent := OutputEvent{Prompt: prompt, S3Key: event.Detail.S3Key, Bedrock: event.Detail.Bedrock}

	return outputEvent, nil
}

func main() {
	lambda.Start(handler)
}
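The core of this function, choosing a mood and building the prompt, can be tried without AWS. The following Go sketch is a variant of getTopSentiment that does not depend on map iteration order. The original iterates over a map, so when two scores tie the chosen mood can differ between runs. Here a fixed-order slice makes the earlier entry win. The sentimentScore type stands in for Comprehend's SentimentScore with plain float64 fields. topMood and buildPrompt are hypothetical names.

```go
package main

import (
	"fmt"
	"strings"
)

// sentimentScore mirrors the four scores returned by Comprehend's DetectSentiment
// (plain float64 fields so the sketch runs without the AWS SDK).
type sentimentScore struct {
	Positive, Negative, Neutral, Mixed float64
}

// topMood returns the mood word for the highest-scoring sentiment.
// Iterating over a fixed-order slice makes ties deterministic: the earlier entry wins.
func topMood(s sentimentScore) string {
	candidates := []struct {
		mood  string
		score float64
	}{
		{"happy", s.Positive},
		{"sad", s.Negative},
		{"neutral", s.Neutral},
		{"neutral", s.Mixed}, // Mixed maps to "neutral", as in getTopSentiment
	}
	best := candidates[0]
	for _, c := range candidates[1:] {
		if c.score > best.score {
			best = c
		}
	}
	return best.mood
}

// buildPrompt puts the mood first, followed by the key phrases, inside the prompt text.
func buildPrompt(mood string, keyPhrases []string) string {
	words := append([]string{mood}, keyPhrases...)
	return fmt.Sprintf("Generate an image based on the following key words: %s", strings.Join(words, " "))
}

func main() {
	mood := topMood(sentimentScore{Positive: 0.91, Negative: 0.02, Neutral: 0.06, Mixed: 0.01})
	fmt.Println(buildPrompt(mood, []string{"a sunny day", "the beach"}))
}
```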

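Image generation function (DALL-E)

This Lambda function takes the prompt sent by the previous function and generates an image with the DALL-E API. It is invoked with the prompt and the Amazon S3 key. It does the following:

  • Reads the OpenAI API key from an environment variable.
  • Creates the DALL-E API request body from the provided prompt.
  • Sends the request to the DALL-E API to generate the image.
  • Reads and parses the response from the DALL-E API.
  • Extracts the image URL from the response.
  • Passes the image URL and the Amazon S3 key on to the next function.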
package main

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"

	"github.com/aws/aws-lambda-go/lambda"
)

type InputEvent struct {
	Prompt string `json:"prompt"`
	S3Key  string `json:"s3Key"`
}

type OutputEvent struct {
	Url   string `json:"url"`
	S3Key string `json:"s3Key"`
}

// Type for how the request to DALL-E should be structured
type DalleRequest struct {
	Model  string `json:"model"`
	Prompt string `json:"prompt"`
	Size   string `json:"size"`
	N      int    `json:"n"`
}

// Type for how the request response is structured
type ResponseObject struct {
	Data []struct {
		Url string `json:"url"`
	} `json:"data"`
}

func handler(ctx context.Context, event InputEvent) (OutputEvent, error) {
	// Grabbing our API key from the environment variable
	openaiApiKey := os.Getenv("OPENAI_API_KEY")
	if openaiApiKey == "" {
		log.Println("Error: OPENAI_API_KEY environment variable is not set")
		return OutputEvent{}, fmt.Errorf("missing OPENAI_API_KEY")
	}

	// Preparing the DALL-E request
	requestBody := DalleRequest{
		Model:  "dall-e-3",
		Prompt: event.Prompt,
		Size:   "1024x1024",
		N:      1,
	}

	// Preparing the payload for the request
	payload, err := json.Marshal(requestBody)
	if err != nil {
		log.Println("Error marshalling data:", err)
		return OutputEvent{}, err
	}

	// Creating the request
	client := &http.Client{}
	req, err := http.NewRequest("POST", "https://api.openai.com/v1/images/generations", bytes.NewBuffer(payload))
	if err != nil {
		log.Println("Error creating request:", err)
		return OutputEvent{}, err
	}

	// Setting headers for the request
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", openaiApiKey))

	// Sending the request to OpenAI
	resp, err := client.Do(req)
	if err != nil {
		log.Println("Error sending request:", err)
		return OutputEvent{}, err
	}
	defer resp.Body.Close()

	// Reading all data from the response
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		log.Println("Error reading response body:", err)
		return OutputEvent{}, err
	}
	// If the response status is not OK, an error is returned
	if resp.StatusCode != http.StatusOK {
		log.Printf("Non-OK HTTP status: %s\nResponse body: %s\n", resp.Status, string(body))
		return OutputEvent{}, fmt.Errorf("non-OK HTTP status: %s", resp.Status)
	}

	// Unmarshalling the response
	var respData ResponseObject
	err = json.Unmarshal(body, &respData)
	if err != nil {
		log.Println("Error unmarshalling response:", err)
		return OutputEvent{}, err
	}

	if len(respData.Data) == 0 {
		log.Println("No data received in the response")
		return OutputEvent{}, fmt.Errorf("no data in the response")
	}

	// Extracting the URL of the image
	imageURL := respData.Data[0].Url

	// Creating the output and returning it to the next function
	outputEvent := OutputEvent{Url: imageURL, S3Key: event.S3Key}

	return outputEvent, nil
}

func main() {
	lambda.Start(handler)
}
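The request and response handling above can be checked without calling OpenAI. The following Go sketch marshals the same request shape and extracts the first URL from a response body. It reuses the DalleRequest and ResponseObject field layout. The sample response is a hypothetical, abbreviated body; a real one contains additional fields, which encoding/json ignores. buildPayload and firstImageURL are hypothetical names.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
)

// Same shapes as DalleRequest and ResponseObject in the function above.
type dalleRequest struct {
	Model  string `json:"model"`
	Prompt string `json:"prompt"`
	Size   string `json:"size"`
	N      int    `json:"n"`
}

type dalleResponse struct {
	Data []struct {
		URL string `json:"url"`
	} `json:"data"`
}

// buildPayload produces the JSON body sent to the image generation endpoint.
func buildPayload(prompt string) ([]byte, error) {
	return json.Marshal(dalleRequest{Model: "dall-e-3", Prompt: prompt, Size: "1024x1024", N: 1})
}

// firstImageURL extracts the URL of the first generated image and
// returns an error (rather than panicking) when the data array is empty.
func firstImageURL(body []byte) (string, error) {
	var resp dalleResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return "", err
	}
	if len(resp.Data) == 0 {
		return "", errors.New("no data in the response")
	}
	return resp.Data[0].URL, nil
}

func main() {
	payload, _ := buildPayload("happy beach")
	fmt.Println(string(payload))
	imageURL, _ := firstImageURL([]byte(`{"data":[{"url":"https://example.com/image.png"}]}`))
	fmt.Println(imageURL)
}
```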

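Image generation function (Bedrock)

This Lambda function takes the text prompt, generates an image with the Amazon Titan Image Generator model, and stores it in the Amazon S3 bucket. It is invoked with the prompt and the Amazon S3 key. It does the following:

  • Reads the Amazon S3 bucket name from an environment variable and loads the AWS configuration.
  • Prepares the payload by wrapping the prompt in the parameters the image generation task requires, including the image size.
  • Invokes the Amazon Titan image generation model with the prepared payload.
  • Receives the base64-encoded image that the model generates from the prompt.
  • Decodes the base64-encoded image into a byte array.
  • Uploads the byte array to Amazon S3.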
package main

import (
	"bytes"
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"os"

	"github.com/aws/aws-lambda-go/lambda"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
)

type InputEvent struct {
	Prompt string `json:"prompt"`
	S3Key  string `json:"s3Key"`
}

type OutputEvent struct {
	Url   string `json:"url"`
}

// Titan specific types for requests towards bedrock
type TitanInputTextToImageInput struct {
	TaskType              string                      `json:"taskType"`
	ImageGenerationConfig TitanInputTextToImageConfig `json:"imageGenerationConfig"`
	TextToImageParams     TitanInputTextToImageParams `json:"textToImageParams"`
}

type TitanInputTextToImageParams struct {
	Text         string `json:"text"`
	NegativeText string `json:"negativeText,omitempty"`
}

type TitanInputTextToImageConfig struct {
	NumberOfImages int     `json:"numberOfImages,omitempty"`
	Height         int     `json:"height,omitempty"`
	Width          int     `json:"width,omitempty"`
	Scale          float64 `json:"cfgScale,omitempty"`
	Seed           int     `json:"seed,omitempty"`
}

type TitanInputTextToImageOutput struct {
	Images []string `json:"images"`
	Error  string   `json:"error"`
}

// Function to decode base64 image
func decodeImage(base64Image string) ([]byte, error) {
	decoded, err := base64.StdEncoding.DecodeString(base64Image)
	if err != nil {
		return nil, err
	}
	return decoded, nil
}

func handler(ctx context.Context, event InputEvent) error {
	// Grabbing the name of the bucket from the environment variable
	bucketName := os.Getenv("BUCKET_NAME")
	if bucketName == "" {
		log.Println("Error: BUCKET_NAME environment variable is not set")
		return fmt.Errorf("missing BUCKET_NAME")
	}

	// Preparing config for the runtime and forcing us-east-1, since the model is not available in eu-north-1
	cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion("us-east-1"))
	if err != nil {
		return err
	}

	// Creating the runtime client for Bedrock
	runtime := bedrockruntime.NewFromConfig(cfg)

	// Preparing the request payload for Bedrock
	payload := TitanInputTextToImageInput{
		TaskType: "TEXT_IMAGE",
		TextToImageParams: TitanInputTextToImageParams{
			Text: event.Prompt,
		},
		ImageGenerationConfig: TitanInputTextToImageConfig{
			NumberOfImages: 1,
			Scale:          8.0,
			Height:         1024,
			Width:          1024,
		},
	}

	payloadString, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("unable to marshal body: %v", err)
	}

	accept := "*/*"
	contentType := "application/json"
	model := "amazon.titan-image-generator-v1"

	// Sending the request to Bedrock
	resp, err := runtime.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{
		Accept:      &accept,
		ModelId:     &model,
		ContentType: &contentType,
		Body:        payloadString,
	})
	if err != nil {
		return fmt.Errorf("error from Bedrock, %v", err)
	}

	var output TitanInputTextToImageOutput
	err = json.Unmarshal(resp.Body, &output)
	if err != nil {
		return fmt.Errorf("unable to unmarshal response from Bedrock: %v", err)
	}
	// Checking the response before indexing into Images, to avoid a panic on an empty result
	if output.Error != "" {
		return fmt.Errorf("error from Titan model: %s", output.Error)
	}
	if len(output.Images) == 0 {
		return fmt.Errorf("no image returned from Bedrock")
	}

	// Decoding base64 to be able to upload the image to S3
	decoded, err := decodeImage(output.Images[0])
	if err != nil {
		return fmt.Errorf("unable to decode image: %v", err)
	}

	// Creating a session for S3 (in the function's own region)
	sesh := session.Must(session.NewSession())
	s3Client := s3.New(sesh)

	objectKey := event.S3Key

	// Uploading image to S3
	_, err = s3Client.PutObject(&s3.PutObjectInput{
		Bucket:      aws.String(bucketName),
		Key:         aws.String(objectKey),
		Body:        bytes.NewReader(decoded),
		ContentType: aws.String(http.DetectContentType(decoded)),
	})
	if err != nil {
		log.Println("Error uploading image to S3:", err)
		return err
	}

	log.Println("Successfully uploaded image to S3:", objectKey)

	return nil
}

func main() {
	lambda.Start(handler)
}

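Image upload function

This Lambda function takes the URL from the previous function, downloads the image, and uploads it to the Amazon S3 bucket. It is invoked with the URL and the Amazon S3 key. It does the following:

  • Reads the bucket name and Region from environment variables.
  • Downloads the image from the provided URL.
  • Reads the image data from the HTTP response.
  • Initializes an Amazon S3 client.
  • Uploads the image data to the specified Amazon S3 bucket under the provided S3 key.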
package main

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"

	"github.com/aws/aws-lambda-go/lambda"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
)

type Event struct {
	Url   string `json:"url"`
	S3Key string `json:"s3Key"`
}

func handler(ctx context.Context, event Event) error {
	// Grabbing the name of the bucket from the environment variable
	bucketName := os.Getenv("BUCKET_NAME")
	if bucketName == "" {
		log.Println("Error: BUCKET_NAME environment variable is not set")
		return fmt.Errorf("missing BUCKET_NAME")
	}

	region := os.Getenv("REGION")
	if region == "" {
		log.Println("Error: REGION environment variable is not set")
		return fmt.Errorf("missing REGION")
	}

	// Downloading image from provided URL
	imageResp, err := http.Get(event.Url)
	if err != nil {
		log.Println("Error downloading image:", err)
		return err
	}
	defer imageResp.Body.Close()

	// Only a 200 OK response contains the image
	if imageResp.StatusCode != http.StatusOK {
		log.Printf("Error: received non-OK HTTP status when downloading image: %s\n", imageResp.Status)
		return fmt.Errorf("non-OK HTTP status when downloading image: %s", imageResp.Status)
	}

	imageData, err := io.ReadAll(imageResp.Body)
	if err != nil {
		log.Println("Error reading image data:", err)
		return err
	}

	// Creating a session for S3 in the region from the REGION environment variable
	sesh := session.Must(session.NewSession(&aws.Config{
		Region: aws.String(region),
	}))
	s3Client := s3.New(sesh)

	objectKey := event.S3Key
	// Uploading image to S3
	_, err = s3Client.PutObject(&s3.PutObjectInput{
		Bucket:      aws.String(bucketName),
		Key:         aws.String(objectKey),
		Body:        bytes.NewReader(imageData),
		ContentType: aws.String(http.DetectContentType(imageData)),
	})
	if err != nil {
		log.Println("Error uploading image to S3:", err)
		return err
	}

	log.Println("Successfully uploaded image to S3:", objectKey)

	return nil
}

func main() {
	lambda.Start(handler)
}

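Get image function

This Lambda function handles requests from Amazon API Gateway. It checks whether a given object exists in the Amazon S3 bucket and, if it does, returns the object's URL. It is invoked for each request and does the following:

  • Reads the bucket name from an environment variable.
  • Extracts the s3Key query parameter from the request.
  • Initializes an Amazon S3 client.
  • Uses s3Key to check whether the object exists in the Amazon S3 bucket.
  • If it exists, returns a response containing the object's URL.
  • If it does not exist, or an error occurs, returns an error message.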
package main

import (
	"context"
	"fmt"
	"log"
	"os"

	"github.com/aws/aws-lambda-go/events"
	"github.com/aws/aws-lambda-go/lambda"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
)

func handler(ctx context.Context, request events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) {
	// Grabbing the name of the bucket from the environment variable
	bucketName := os.Getenv("BUCKET_NAME")
	if bucketName == "" {
		log.Println("Error: BUCKET_NAME environment variable is not set")
		// A missing environment variable is a server-side problem, not a bad request
		return events.APIGatewayProxyResponse{
			StatusCode: 500,
			Body:       "missing BUCKET_NAME",
		}, nil
	}

	// Extracting the S3 key
	s3Key, exists := request.QueryStringParameters["s3Key"]
	if !exists || s3Key == "" {
		return events.APIGatewayProxyResponse{
			StatusCode: 400,
			Body:       "Missing query parameter 's3Key'",
		}, nil
	}

	// The Lambda runtime sets AWS_REGION; the bucket is in the same region as the stack
	region := os.Getenv("AWS_REGION")

	// Creating a connection to S3
	sesh := session.Must(session.NewSession())
	s3Client := s3.New(sesh)

	input := &s3.HeadObjectInput{
		Bucket: aws.String(bucketName),
		Key:    aws.String(s3Key),
	}

	// Checking if the image exists
	_, err := s3Client.HeadObject(input)
	if err != nil {
		log.Printf("Error checking if object %q exists: %v", s3Key, err)
		// Any error (typically "not found" while the image is still being generated) is reported as 404
		return events.APIGatewayProxyResponse{
			StatusCode: 404,
			Body:       "Object does not exist",
		}, nil
	}

	// If it does exist, return its URL as JSON
	return events.APIGatewayProxyResponse{
		StatusCode: 200,
		Headers:    map[string]string{"Content-Type": "application/json"},
		Body:       fmt.Sprintf(`{"url": "https://%s.s3.%s.amazonaws.com/%s"}`, bucketName, region, s3Key),
	}, nil
}

func main() {
	lambda.Start(handler)
}
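The function builds the object URL and the JSON body with fmt.Sprintf, which works for simple keys such as image/testing. The following Go sketch is a slightly more defensive alternative. It assumes keys may contain characters such as spaces, so it builds the virtual-hosted-style URL with net/url and encodes the body with encoding/json, which keeps the body valid JSON whatever the URL contains. objectURL and responseBody are hypothetical helper names.

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/url"
)

// objectURL builds https://<bucket>.s3.<region>.amazonaws.com/<key>.
// url.URL escapes characters such as spaces in the key but keeps "/".
func objectURL(bucket, region, key string) string {
	u := url.URL{
		Scheme: "https",
		Host:   fmt.Sprintf("%s.s3.%s.amazonaws.com", bucket, region),
		Path:   "/" + key,
	}
	return u.String()
}

// responseBody encodes {"url": ...} with encoding/json instead of string formatting.
func responseBody(objURL string) (string, error) {
	b, err := json.Marshal(struct {
		URL string `json:"url"`
	}{URL: objURL})
	return string(b), err
}

func main() {
	u := objectURL("my-stack-image-upload-bucket", "eu-north-1", "image/testing")
	body, _ := responseBody(u)
	fmt.Println(body)
}
```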

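How to build: frontend

To keep the frontend simple and intuitive, it is built with a prebuilt component library and a few powerful out-of-the-box tools for making requests.

To build the frontend, first initialize a new TypeScript-based React project in a new folder with the command npx create-react-app my-app --template typescript.

Next, create a components folder with a mainContent subfolder, and save an index.tsx file in that subfolder. The frontend is not the focus of this project, so everything is kept in a single component. The component uses useQuery and useMutation, so the application root must be wrapped in React Query's QueryClientProvider. The final component code is as follows: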

import React from "react";
import { Button, Flex, Layout, Spin, Image, Switch } from "antd";
import TextArea from "antd/es/input/TextArea";
import { Typography } from "antd";
import { useMutation, useQuery } from "@tanstack/react-query";
import { getImage, sendPost } from "../../utils/api";

const { Paragraph, Text } = Typography;

const { Header, Content } = Layout;

const headerStyle: React.CSSProperties = {
  textAlign: "center",
  color: "#fff",
  height: 64,
  paddingInline: 48,
  lineHeight: "64px",
  backgroundColor: "#00415a",
};

const contentStyle: React.CSSProperties = {
  textAlign: "center",
  minHeight: 120,
  lineHeight: "120px",
  color: "#fff",
  backgroundColor: "#00719c",
};

const layoutStyle = {
  borderRadius: 8,
  overflow: "hidden",
  width: "calc(50% - 8px)",
  maxWidth: "calc(50% - 8px)",
  marginTop: "10vh",
  height: "100%",
};

const textAreaStyle = {
  width: "80%",
  height: "80%",
};

const buttonStyle = {
  width: "80%",
  height: "80%",
  marginBottom: 16,
  backgroundColor: "#009bd6",
};
const paragraphStyle = {
  margin: 10,
};

const textStyle = {
  color: "white",
};

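export default function MainContent() {
  const [post, setPost] = React.useState("");
  const [image, setImage] = React.useState("");
  // S3 key of the image being generated; a new key per request keeps
  // polling from returning an image left over from an earlier run
  const [s3Key, setS3Key] = React.useState("");
  // Polling stays off until an image has been requested
  const [refetch, setRefetch] = React.useState(false);
  const [loading, setLoading] = React.useState(false);
  const [bedrock, setBedrock] = React.useState(false);

  // A function to trigger the generation of an image
  const { mutate } = useMutation({
    mutationFn: () => {
      const key = `image/${Date.now()}`;
      setS3Key(key);
      setImage("");
      setLoading(true);
      return sendPost("/generate-image", {
        text: post,
        s3Key: key,
        bedrock: bedrock,
      });
    },
    // sendPost never throws, so check the status before starting to poll
    onSuccess: (response) => {
      if (response.status >= 200 && response.status < 300) {
        setRefetch(true);
      } else {
        setLoading(false);
      }
    },
  });
  // Polls /get-image every 3 seconds. When status 200 is returned, polling stops and the image is displayed
  useQuery({
    queryKey: ["image", s3Key],
    refetchInterval: 3000,
    refetchIntervalInBackground: true,
    enabled: refetch && s3Key !== "",
    queryFn: async () => {
      const imageData = await getImage("/get-image", { s3Key });
      if (imageData.status === 200) {
        setRefetch(false);
        setImage(imageData.data);
        setLoading(false);
      }
      // A query function must not resolve to undefined
      return imageData.data ?? null;
    },
  });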

  const handlePostChange = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
    setPost(event.target.value);
  };

  const handleBedrockChange = (checked: boolean) => {
    setBedrock(checked);
  };

  return (
    <Spin spinning={loading}>
      <Flex justify="center">
        <Layout style={layoutStyle}>
          <Header style={headerStyle}>Epic Post To Image POC</Header>
          <Content style={contentStyle}>
            <Paragraph style={paragraphStyle}>
              <Text strong style={textStyle}>
                Write a post as you would for a social media platform. Click
                generate to create an image for your post.
              </Text>
            </Paragraph>

            <TextArea
              style={textAreaStyle}
              rows={4}
              placeholder="Write here"
              onChange={handlePostChange}
            />
            <Paragraph style={paragraphStyle}>
              <Text strong style={textStyle}>
                Use Bedrock? (Toggle to use bedrock)
              </Text>

              <Switch onChange={handleBedrockChange} />
            </Paragraph>
            {
              // If the image exist it will be shown
              image && (
                <Image
                  width={"80%"}
                  style={{ marginTop: "10px" }}
                  src={image}
                />
              )
            }
            <Button type="primary" style={buttonStyle} onClick={() => mutate()}>
              Generate
            </Button>
          </Content>
        </Layout>
      </Flex>
    </Spin>
  );
}

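The component above handles almost everything except the request functions. These live in an api.ts file in the utils folder (src/utils/api.ts, matching the ../../utils/api import). They read the API base URL from the REACT_APP_API_URL environment variable, which should point to the deployed API Gateway endpoint. The code is as follows: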

import axios from "axios";

const API_URL = process.env.REACT_APP_API_URL;

export async function sendPost(
  url: string,
  data: { text: string; s3Key: string; bedrock: boolean }
): Promise<{ status: number; data: string }> {
  try {
    const response = await axios({
      method: "post",
      url: `${API_URL}${url}`,
      data,
    });

    return {
      status: response.status,
      data: response.data,
    };
  } catch (error: any) {
    return {
      status: error.response?.status || error.status,
      data: error.response,
    };
  }
}

export async function getImage(
  url: string,
  params: { s3Key: string }
): Promise<{ status: number; data: string }> {
  try {
    const response = await axios({
      method: "get",
      url: `${API_URL}${url}`,
      params,
    });

    return {
      status: response.status,
      data: response.data.url,
    };
  } catch (error: any) {
    return {
      status: error.response?.status || error.status,
      data: error.response,
    };
  }
}

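The React component does the following:

  • Displays a layout with a header and content.
  • Lets the user write a post in a text area.
  • Provides a toggle that lets the user choose Amazon Bedrock instead of DALL-E, which is used by default.
  • Sends the post data to the backend to generate an image when the user clicks the "Generate" button.
  • Polls the backend for the generated image and displays it as soon as it is available.
  • Shows a loading spinner while the image is being requested, until the image arrives.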

For the complete frontend implementation, see the code repository.

The final frontend looks like this:

Results and conclusion

Result with the DALL-E model:

DALL-E

Result with the Amazon Bedrock model:

Amazon Bedrock

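There is still a lot that could be improved in the project itself. For example, the polling could be replaced with a WebSocket API.

Learn more

For more details on the topics covered in this project, see the following links: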