Amazon Web Services ブログ
API と OSS 、蓄積したデータで精度を改善するならどちらの基盤モデルを選択すべきか : 質問回答編
Amzon Bedrock や ChatGPT のように API 経由で呼び出せる基盤モデルの精度とコストは実用的なレベルに到達しています。一方で、皆さんが開発している製品やサービス、プロダクトには様々なデータが蓄積されていると思います。そのデータで機械学習モデルを学習できれば、より顧客のニーズに合った体験を提供できます。体験が改善されればより多く顧客が集まり、そこから得られるデータはさらなるモデルの改善につながります。 API で利用できるモデルは追加学習なしに高い精度で推論できるものの、最初から顧客が満足するレベルの精度が出せるとは限りません。例えばカスタマーサポートの応対で使う場合、顧客の言葉の意味を取り違えたり、応対マニュアルと異なる対応や回答を伝えてしまう可能性があります。蓄積されたデータから不適切な回答を抜き出し、分析結果をもとに基盤モデルへの指示方法 ( プロンプト ) やモデル自体を調整できれば、より顧客のニーズに沿った体験を提供できます。
本文書では、蓄積したデータによる精度改善を視野に入れた場合、 API と OSS のどちらがコスト効率が良くなるのかを検証します。 OSS とは、 Hugging Face などで公開されている基盤モデルをカスタマイズし、ホスティングする利用形態を想定しています。学習なしの推論で OSS のモデルが API 経由で利用できるモデルに並ぶのは現時点で難しいですが、学習データがあるなら話は変わってきます。 API の場合はプロンプトの修正 (Prompt Engineering) が精度改善の主な手段ですが、 OSS ではモデルの追加学習 (Fine Tuning: 本記事では Instruction Tuning を指します ) も選択肢になります。最近では対話形式に沿うよう学習 (Instruction Tuning) されたモデルが次々と公開されているため、 素の基盤モデルを用いた前回の検証 よりも少ないデータ量で API に並ぶ精度が得られると期待されます。
実験では、API として Amazon Bedrock で利用できる Claude と Claude Instant の 2 つ、OSS として open-calm-1b 、 japanese-gpt-neox-3.6b-instruction-ppo、bilingual-gpt-neox-4b-instruction-ppo 、ELYZA-japanese-Llama-2-7b-instruct 、Swallow-13b-instruct-hf の 5 つを使用しました。 1b 、 3b 、 7b 、10b クラスのモデルを取りそろえた形です。
モデル名 | 公開元 | 種別 | 概要 |
Claude v2.1 | Anthropic | API | Anthropic の提供する高性能な基盤モデル。 Hugging Face の Leaderboard では GPT-4 などに次ぐ精度。日本語性能でも、 Rakuda Ranking などでトップレベルの性能を示す。 20 万トークンという長大なテキストを扱える。 |
Claude Instant | Anthropic | API | 高速な応答に重点を置いたモデル。 10 万トークンという長大なテキストを扱える。 |
open-calm-1b | CyberAgent | OSS | 株式会社サイバーエージェントから公開された GPT-NeoX ベースの日本語大規模言語モデル。 |
japanese-gpt-neox-3.6b-instruction-ppo | rinna | OSS | rinna 株式会社から公開された日本語で学習された GPT-NeoX ベースの大規模言語モデル。対話形式のデータで教師あり学習、強化学習が行われた日本語大規模言語モデル。 |
bilingual-gpt-neox-4b-instruction-ppo | rinna | OSS | rinna 株式会社から公開された日英両言語で学習された GPT-NeoX ベースの大規模言語モデル。対話形式のデータで教師あり学習、強化学習を行っている。 |
ELYZA-japanese-Llama-2-7b-instruct | ELYZA | OSS | 株式会社 ELYZA から公開された、 Meta の Llama2 をもとに日本語コーパスで継続学習した大規模言語モデル。独自データでの教師あり学習を行っている。 |
Swallow-13b-instruct-hf | 東工大 / 産総研 | OSS | 東工大と産総研の研究チームから公開された、Meta の Llama2 をもとに日本語コーパスで継続学習した大規模言語モデル。 |
Claude 2.1 は GPT-4 、 Claude Instant は ChatGPT 3.5 に近しい精度なので値は読み替えていただけると思います。評価用のデータセットは幅広な選択肢がありますが、今回は JGLUE の中でも質問回答のデータセットである JSQuAD を使用しました。 JSQuAD には、 Wikipedia の文書とそれに関する質問、質問に対する回答箇所が収録されています。近年、検索と生成を組み合わせた RAG (Retrieval Augmented Generation) と呼ばれる手法が注目されていますが、検索結果から要求に応じた情報を抽出するタスクは JSQuAD の形式に近く、 RAG で使用するモデルを選ぶ際の参考指標となり得ると考えたためです。
本記事では、気になる実験結果を先に提示し、のちのセクションで実験の内容について詳細に解説します。
API と OSS のコスト効率の比較結果
コスト効率とは、精度とコストの比率を指します。精度、コストの比較結果を示し最後にコスト効率について示します。
精度は次の図の結果となりました。縦軸は F1 という実際の回答と出力された回答がどれだけオーバーラップしているかを表す指標 ( 後述します ) 、横軸は使用した JSQuAD の学習データの件数です。API である Claude についてはプロンプトに含めた例示 ( データ ) の数、 OSS については追加学習に使用したデータの件数を示します。 API か OSS かで「データ件数」の意味が異なる点、また件数が対数軸になっている点にはご注意ください。追加学習には LoRA の手法を用い、ハイパーパラメーターやエポック数などの設定は各データ件数でそろえていますが、規定時間以内に終わらなかった学習は途中で打ち切っています。
この結果からは、次の 5 つの示唆が得られます (Claude をハイエンドのモデル、 Claude Instant を軽量なモデルと表現しています )。なお、得られる示唆は質問回答タスクに限定される点に注意してください。
1) 10B クラスの日本語 OSS モデルは、ハイエンドの API モデルを利用する場合の精度に匹敵する。
2) 7B クラスの日本語 OSS モデルでも、 30 件以上のデータがあれば追加学習で軽量な API モデルの精度に匹敵する。
3) 7B クラスの日本語 OSS モデルでも、 500 件以上のデータがあれば追加学習でハイエンドの API モデルの精度に匹敵する。
4) 4B 以下のモデルは追加学習しても API 経由のモデルの精度に至らず、また過学習により精度が下がる可能性がある。
5) プロンプトの例示は 2 件あれば十分効果が得られる。
続いて、コストを示します。縦軸は金額 ($) 、横軸はデータ件数です。金額は OSS の場合学習にかかった費用と検証データセットの推論にかかった費用の合計、 API の場合推論にかかった費用のみになります。 API では、ハイエンドのモデルは性能が高い分、やはり費用がかかることがわかります。 OSS では、おおむね 4000 件を超えたあたりから学習コストに対し得られる性能が割に合わなくなることが読み取れます ( 4B の rinna のモデルで少し跳ねがありますが、全体の傾向に影響ないと見ています ) 。先ほどの例から、十分な精度が得られるのは 500 件程度のため、あまり大量の学習データを用意するのは経済的ではないと言えそうです。
最後に、コスト効率の目安として F1 の値を学習コストで割った値をプロットした図を示します。 $1 の学習コストを払うことでどれくらい精度である F1 が向上するか、つまりコスト効率を示す指標になります。 7B のモデルである ELYZA が 32 件の時に突出しており、以後基本的には下がることがわかります。 1B の OpenCALM や 4B の rinna は初期コスト効率が低いものの 500 件近辺で持ち直し、以後は他のモデル含め下降傾向を示しています。そのため、まず 30 件以上、 500 件までは十分な費用対効果が期待できると言えると思います。この知見は、モデルにとって適切なプロンプトを探索したい場合どの程度データ件数をそれぞれ用意すればよいのかの疑問にも示唆を与える結果です。
質問回答タスクに基づく検証から、 API と OSS の使い分けは以下のようにするとよいのではないかという示唆が得られます。なお、文中では冗長性を省くため断定的に書いていますが、あくまで質問回答データセットを用いた本実験から得られる示唆と認識ください 。
- API 経由で利用し精度に課題がある場合、 2 つ程度プロンプトに例示入れることで確かな精度の向上を確認できる。
- Claude Instant 、あるいは ChatGPT 3.5 相当の精度は 7B クラスのモデルを 30 件以上のデータで追加学習することで到達できる。精度が十分であり、安定性や速度の課題がホスティング費用に勝るなら切り替える価値がある。
- Claude 2.1 、あるいは GPT-4 相当の精度が必要な場合、 1) 10B クラスの OSS モデルを使用するか、 2) 7B クラスの OSS モデルを 500 件程度のデータで追加学習し目的の精度が得られるか検証する。精度が十分得られ、安定性や速度の課題がホスティング費用に勝るなら切り替える価値がある。
以下のセクションでは、結論に至るまでの実験設定について記載します。今後、他モデルの数、また要約や分類といった他のタスクについても検証を検討しています。
API と OSS の基盤モデル
現在、基盤モデルを利用する選択肢は大きく分けて API と OSS の 2 つがあります。
API は ChatGPT や Amazon Bedrock のように Web API 経由で基盤モデルを利用する形式です。基盤モデルをホスティングするインフラを意識することなく使うことができ、多くの場合処理したトークン数に応じて費用を支払います。 Anthropic の Claude や OpenAI の ChatGPT など、非常に性能が高いモデルを安価に利用できることが特徴です。トークン数に応じて課金されるため、詳細なプロンプトや例示を書くほどコストがかかることになります。
OSS は Hugging Face などで公開されているオープンソースのモデルを GPU インスタンス等にホスティングして利用します。推論用のコードやインフラを準備する手間があるものの、一度立ち上げてしまえば API リクエスト数や処理トークン数を意識することなく使用できます。 また、 API 提供者の設定するリクエスト制限、サービス停止などの影響を受けることもありません。Amazon SageMaker JumpStart を使えば、モデルのホスティングもボタン操作のみで行えます。現在、 rinna と Stability AI の日本語モデルが掲載されており、今後も増える予定です。 OSS であるため、手元のデータで追加学習しカスタマイズすることもできます ( 追加学習、また追加学習後のモデルの利用と公開についてはライセンスを注意深く確認してください ) 。カスタマイズは API でもできるようになってきていますが、 Amazon Bedrock では時間単位課金のモデルユニットが必要であり、 ChatGPT でもカスタマイズしたモデルを使うときの料金は 3 倍近くになります (2024/1/31 時点) 。追加学習後の推論のコスト効率を考える場合、 OSS のモデルは良い選択肢になるでしょう。
評価手法 : データセット
今回は評価に JSQuAD という質問回答のデータセットを使用します。中のデータは次のような形式をしています。SQuAD (The Stanford Question Answering Dataset)というデータセットを参考に作成されており、 Wikipedia の記事 ( context
) に対する質問 ( question
) と回答 ( answers
) が収録されています (answers
は 1 件のみです)。 SQuAD2.0 では context
に答えがない場合 is_impossible : true
のケースが存在しますが、 JSQuAD は SQuAD 1.1 をベースにしており答えられない質問は含まれていません。
{
"title": "東海道新幹線 (Tokaido Shinkansen)",
"paragraphs": [
{
"qas": [
{
"question": "2020年(令和2年)3月現在、東京駅 - 新大阪駅間の最高速度はどのくらいか。 (What is the maximum speed between Tokyo Station and Shin-Osaka Station as of March 2020?)",
"id": "a1531320p0q0",
"answers": [
{
"text": "285 km/h",
"answer_start": 182
}
],
"is_impossible": false
},
{
..
}
],
"context": "東海道新幹線 [SEP] 1987年(昭和62年)4月1日の国鉄分割民営化により、JR東海が運営を継承した。西日本旅客鉄道(JR西日本)が継承した山陽新幹線とは相互乗り入れが行われており、東海道新幹線区間のみで運転される列車にもJR西日本所有の車両が使用されることがある。2020年(令和2年)3月現在、東京駅 - 新大阪駅間の所要時間は最速2時間21分、最高速度285 km/hで運行されている。"
}
]
}
1 つの context
には複数の質問があります。同じ context
のデータは類似性が高いため、学習データを作る際は context
が重複しないようにしています ( データ数が 15,000 件までは重複させないことができました ) 。
JSQuAD の学習データからランダムにサンプリングしデータ件数ごとの学習データを作成します。データを全く与えない場合を 0、そこから 2 、 4 、 8 、 16 ・・・と 2 の倍数刻みで 8192 件まで context
の重複がないようランダムに選択したデータセットを用意しました。さらに、 15652 件 (context
重複なしで作れる最大のデータ数 )、31005 件( context
重複を最大 2 回許可)、45576 件(context
重複を最大 3 回許可)、57086 件(context
重複を最大 4 回許可)、62859 件(context
重複が最大 5 件、 JSQuAD 内の全質問)のデータセットを用意しました。件数ごとのデータセットについて、 API はプロンプト内の例示に、 OSS ではモデルの追加学習に使用します。データが少なすぎると追加学習が困難なので、 追加学習は 8 件からスタートをしています。逆に、 API のプロンプトに入力できる事例数は 8 件を上限にしています。今回、評価データセット全件を評価するために Amazon Bedrock で Preview 中の Batch inference の機能 を使用したのですが、その容量の制限上この値に留めています。 128 件入れた研究もあり (Cold-Start Data Selection for Few-shot Language Model Fine-tuning: A Prompt-Based Uncertainty Propagation Approach ) 、 Claude は 20 万トークンものサイズを扱えるので GA した際には推論可能な量も増えることを期待しています。
評価手法 : 精度の算出
JSQuAD の問題を基盤モデルで解き精度を計測するにはどうすればよいでしょうか ? 例えば、 Claude の場合は次のようにプロンプトを与えています。 input には context
、instruction には question
が入ります。基盤モデルの回答と answers
が一致するかで精度を評価します。
Human: 与えられたinputからinstructionに対する回答を抽出する関数を実行してください。
入出力のexampleを示します。
<example>
<input>・・・</input>
<instruction>・・・</instruction>
Answer:xxxxx
</example>
<example>
<input>・・・</input>
<instruction>・・・</instruction>
Answer:xxxxx
</example>
次のinputからinstructionに対する回答を抽出してください。結果はAnswer:の後に記載し名詞以外何も含まないことを厳守してください。
<input>・・・</input>
<instruction>・・・</instruction>
Assistant:Answer:
精度は回答と完全に一致した Exact Match の数、予測の中にどれだけ回答と一致する文字が含まれるかを計測する F1 で行います。 JSQuAD の評価において、 F1 は文字単位で計算します。英語では単語単位なのですが、日本語の場合形態素解析の仕方で評価が揺らいでしまうためです。 「 285 km/h 」が正解の場合、 “285 km/h” と完全に回答できれば Exact Match ですが “285” だと Exact Match にはなりません。 F1 の場合 “2” “8” “5” は一致していると評価されるのでより寛容な評価になります。ただ、今回評価に使用した lm-evaluation-harness では JSQuAD の評価を単語単位で計算しているので、本記事の F1 の値は本来の値とは少し異なります。上記は Claude の例を示しましたが、プロンプトの形式は各 API / OSS の形式に準じます。例えば、 Claude であれば \n\nHuman:
と \n\nAssistant:
の対話形式に、 OSS では例えば rinna のモデルであれば "ユーザー: "
or "システム: "
の形式に合わせます。 lm-evaluation-harness ではタスクの指定でプロンプトのテンプレートを切り替えることができます。例えば、 rinna 用のテンプレートは jsquad-1.1-0.4 になります。
評価手法 : コストの算出
コストはどう評価すればよいでしょうか ? 学習と推論の 2 つのコストがあります。 OSS の場合は学習と推論にかかったコスト、 API の場合は推論のみにかかったコストが計上対象となります。 OSS のモデルの学習は、 OpenCALM の検証の時と同じ実装を使用しました。 aws-ml-jp 上のサンプルを使用することで、 Notebook を実行するのみで簡単に Hugging Face 上のモデルを LoRA 形式で追加学習できます。テンプレートはモデルごと合ったものを使用し、 3 エポック学習を回しています。学習に使用したインスタンスは NVIDIA A10G の GPU が搭載された g5.2xlarge なので、インスタンスの時間当たりの単価と学習にかかった時間をかけ合わせればコストが計算できます。次の図は ELYZA の学習時間を示した図です。
価格を計算する際は、オンデマンド価格を参照しました ( 執筆時点で $1.212 / 時間 ) 。精度が十分な値になる 512 件では学習に 15 分程度、金額にして 44 円ぐらいになります。小学校の遠足のおやつ代は平均 426 円とのことなので、約 1/10 の金額でモデルが遠足に行って成長して帰ってくると考えると割安と感じます。スポットインスタンスなどを利用すればより安価になります。推論も同様にかかった時間に単価をかけて計算しましたが、 Swallow のみ g5.2xlarge に乗せるため 8bit の量子化をしています。
API のコストは、入出力のトークン数から計算します。ただし、評価用データ (validation dataset) が 4000 件近くあり、普通に 1 件 1 件推論しているとあっという間に 1 分当たりのリクエスト上限に達してしまいます。そのため、 Preview 中の Batch inference の機能を使用しました。次に Batch inference のサンプルコードを示します。 Amazon Simple Storage Service にデータをアップロードし、 create_model_invocation_job
を実行します。roleArn
は Permission を参照し作成したロールの arn を設定します。
import time
import boto3
bedrock = boto3.client(service_name="bedrock")
inputDataConfig = ({
"s3InputDataConfig": {
"s3Uri": "s3://input-bucket/input/abc.jsonl"
}
})
outputDataConfig = ({
"s3OutputDataConfig": {
"s3Uri": "s3://output-bucket/output/"
}
})
response = bedrock.create_model_invocation_job(
roleArn="arn:aws:iam::123456789012:role/MyBatchInferenceRole",
modelId="amazon.titan-text-express-v1",
jobName="my-batch-job",
inputDataConfig=inputDataConfig,
outputDataConfig=outputDataConfig
)
job_id = response.get('jobArn')
status = 'Begin'
while status not in ('Completed', 'Failed', 'Stopped'):
time.sleep(5)
status = bedrock.get_model_invocation_job(jobIdentifier=job_id)[
"status"
]
定期的に get_model_invocation_job
を実行しステータスを確認します。私が実行したところ、ジョブが実行開始になるまで ( Submitted
から InProgress
になるまで ) 数時間待たされることもあり、 Preview 以後改善されることを期待しています。
Batch inference は推論結果だけでなく、そのバッチで入出力されたトークン数が出力されます。以下は、 2 shot でプロンプトを書いたときの結果です。入出力のトークン数にそれぞれの単価をかけることでコストを算出できます。
{
"processedRecordCount":4442,
"successRecordCount":4442,
"errorRecordCount":0,
"inputTokenCount":3418270,
"outputTokenCount":50038
}
実験コードは以下で公開しています。 Batch inference の実装サンプルはまだ少ないと思いますので、参考にしていただければ幸いです! https://github.com/aws-samples/aws-ml-jp/pull/66
実験結果と今後の展望
実験結果は冒頭ご紹介した通りです。どれぐらいのデータがあれば OSS で公開されているモデルを追加学習し期待の精度が得られるのか、参考となる指標を示すことができたと思います。インスタンスをホスティングする場合常時稼働する点がネックになりますが、文書解析などリアルタイムで行う必要がないもの、ゲームなどで多様なキャラクターのバリエーションが必要でアクセス頻度も高い場合など、カスタマイズ性とホスティングによるレスポンスの速度や安定性が光るシーンは多々あると考えています。また、モデルを量子化しさらに C++ で最適化することで AWS Lambda 上で推論するなどできれば、実質サーバーレスの API と同等の体験が実現できるでしょう。
今後は、他タスクでの検証、また 7B のモデルを軸にサーバーレス形式での推論が行えないかなどを深堀できればと考えています。
著者プロフィール