Generative AI

NeMo 2.0 を使った VLM 開発: ファインチューニングから推論、評価

Reading Time: 10 minutes

NeMo Framework とは

NeMo Framework は、LLM をはじめ、生成 AI モデルを構築、カスタマイズするためのクラウドネイティブなフレームワークです。NGC 上にコンテナーが公開されており、すぐに利用を開始することができます。

NeMo Framework は、NGC 上に公開されているコンテナーを無償利用していただくこともできますが、NVIDIA AI Enterprise の対象ソフトウェアとなっているため、エンタープライズ サポートを希望される場合は NVIDIA AI Enterprise ライセンスの購入をご検討ください。

生成 AI のワークフロー

図 1. 生成 AI の開発ワークフロー

生成 AI の開発におけるタスクには以下のようなものがあります。

  • 事前学習に必要な大規模データの準備
  • 分散学習を利用した 生成 AI モデルの事前学習
  • モデル をカスタマイズするためのファインチューニングやアライメントおよびプロンプト エンジニアリング
  • モデルの推論を高速化するための最適化
  • GPU を最大限に活用したモデルのサービング
  • コストを抑えながらモデルに最新情報を反映させるための RAG
  • 生成 AI アプリケーションの意図しない挙動を抑えるためのガードレール

生成 AI モデル の開発、サービスでの利用には多くのステップが必要になりますが、NeMo Framework コンテナーには、データの準備からモデルの学習、カスタマイズに必要な下記ライブラリが含まれており、これらを使用することでモデルの構築に関するステップを 1 つのコンテナー環境で実行できます。

  • NeMo Curator
    LLM の学習に必要な大規模データセットのダウンロードから抽出、クリーニング、フィルタリングなどを行うためのスケーラブルなツールキット。
  • NeMo
    LLM、マルチモーダル、音声などの生成 AI モデルを構築するためのスケーラブルなフレームワーク。
  • Megatron-LM
    Transformer モデルの大規模学習に関する研究プロジェクト。このリポジトリ内の Megatron-Core が NeMo で使用されている。
  • Transformer Engine
    FP8 を中心とした Transformer モデルを高速化させるツールキット。Megatron-Core で使用されている。

これらのライブラリは、GitHub 上に OpenSource として公開されていますが、依存関係が解消されている NeMo Framework コンテナーから利用することをお薦めします。コンテナーの場合、/opt ディレクトリ配下に上記のモジュールが配置されています。

NeMo 2.0

NeMo 1.0 や NeMo 2.0 という名は API の呼称であり、NeMo リポジトリのバージョンを指したものではありません。

これまでのNeMo (1.0) は事前学習や SFT (Supervised Fine-Tuning)、PEFT (Parameter-Efficient Fine-Tuning) などそれぞれのジョブに対応したコンフィグ ファイル (YAML ファイル) を編集し、ジョブ スクリプトを実行するインターフェイスをとってきました。このアプローチは実験の設定を宣言的にして 1 つのスクリプトを実行するだけでジョブが完結できる利点がありましたが、一方で柔軟性やプログラム制御の面では制限がありました。

NeMo 2.0 では、Python ベースのコンフィグに移行することで開発者へ以下の利点を提供します:

  • コンフィグに対するより高い柔軟性と制御性
  • IDE との統合が向上し、コード補完や型チェックが利用可能
  • プログラムによるコンフィグの拡張やカスタマイズが容易

具体的な例として以下に NeMo 1.0 の API を利用した LoRA (PEFTの 1 つ) の実行スクリプトの一部を挙げます。これまでは実行スクリプト + Hydra を使用して、ベースとなっているコンフィグ ファイルを上書き (もしくはコンフィグ ファイルを直接書き換え) することで実験の構成を変更していました。

torchrun --nproc_per_node=1 \
/opt/NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py \
    exp_manager.exp_dir=${EXP_DIR} \
    exp_manager.name=${EXP_NAME} \
    trainer.precision=bf16 \
    trainer.devices=1 \
    trainer.num_nodes=1 \
    trainer.max_steps=100 \
    trainer.val_check_interval=100 \
    model.restore_from_path=${MODEL} \
    model.peft.peft_scheme="lora" \
    model.tensor_model_parallel_size=${TP_SIZE} \
    model.pipeline_model_parallel_size=${PP_SIZE}

これが NeMo 2.0 API では以下のようになります。ここでは NeMo 1.0 API と対応関係のある recipe を使った方法を提示します。recipe は NeMo 1.0 のコンフィグ ファイルと同等の位置付けで recipe 内にはモデルごとにデフォルトの構成が用意されています。Python スクリプト内からアクセス可能で必要に応じて編集、拡張することができます。

import nemo_run as run
from nemo.collections import llm


def configure_recipe(nodes: int = 1, gpus_per_node: int = 1):
    recipe = llm.llama3_8b.finetune_recipe(
        dir="/checkpoints/llama3_finetuning",  # Path to store checkpoints
        name="llama3_lora",
        num_nodes=nodes,
        num_gpus_per_node=gpus_per_node,
        peft_scheme="lora",
    )

    recipe.trainer.max_steps = 100
    recipe.trainer.num_sanity_val_steps = 0

    # Async checkpointing doesn't work with PEFT
    recipe.trainer.strategy.ckpt_async_save = False

    # Need to set this to 1 since the default is 2
    recipe.trainer.strategy.context_parallel_size = 1
    recipe.trainer.val_check_interval = 100

    # This is currently required for LoRA/PEFT
    recipe.trainer.strategy.ddp = "megatron"

    return recipe


if __name__ == "__main__":
    run.run(configure_recipe(), executor=run.LocalExecutor())

Python をベースとしたインターフェイスの導入により、NeMo 1.0 API に比べて、コンフィグをより柔軟に扱うことが可能になり、またカスタマイズしたモジュールを使用することもより容易になりました。

VLM (Vision Language Model) とは

VLM は、「画像や動画などの視覚情報」と「テキストなどの言語情報」を同時に処理できる生成 AI モデルです。これまでの画像認識モデルや言語モデルでは画像だけ、またはテキストだけを扱うのが一般的でしたが、VLM ではこの 2 つを統合して扱うことで、例として「画像の内容を説明する文章を生成する」、「画像に関する質問にテキストで答える」といった高度なタスクをこなせます。

NeMo 2.0 では VLM の学習に必要な機能が揃っており、開発者は容易に VLM のカスタマイズを実行できます。

VLM チュートリアル

本記事では meta-llama/Llama-3.2-11B-Vision-Instruct に対して日本語の VLM Instruction データセットである llm-jp/llava-instruct-ja を使用した SFT を実行し、推論、評価までを行います。

本チュートリアルでの手順は以下の通りです。

  • 事前準備
  • NeMo Framework のコンテナーを起動
  • モデルのダウンロードと NeMo フォーマットへ変換
  • データセットの準備
  • SFT の実行
  • HuggingFace フォーマットへ変換
  • VLM の推論
  • VLM の評価

また、今回のチュートリアルの検証環境は以下の条件で行っております。

  • ハードウェア
    • DGX H100
    • GPU: 8 x NVIDIA H100 80 GB GPUs (driver version: 535.230.02)
    • CPU: Intel(R) Xeon(R) Platinum 8480C
    • システム メモリ: 2 TB
  • ソフトウェア
    • OS: Ubuntu 22.04.5 LTS
    • Container: nvcr.io/nvidia/nemo:25.04

事前準備

以下のコマンドで作業用のディレクトリを作成し、移動します。

mkdir vlm-example
cd vlm-example

Docker コンテナーの起動

以下のコマンドでコンテナーを起動します。

sudo docker run --rm -it --gpus all --shm-size=16g --ulimit memlock=-1 --network=host -v ${PWD}:/workspace -w /workspace nvcr.io/nvidia/nemo:25.04 bash

モデルのダウンロードと NeMo フォーマットへの変換

このチュートリアルでは、meta-llama/Llama-3.2-11B-Vision-Instruct を使用します。このモデルは、アクセス許可が必要なため、huggingface のアカウントで許可をとった後に以下の環境変数に自身のトークンを設定します。

export HF_TOKEN="hf_YOUR-HuggingFace-Token" # Change this to your Huggingface token

NeMo 1.0 ではモデルをダウンロードした後に変換スクリプトを使用して NeMo フォーマットへチェックポイントを変換する必要がありましたが、NeMo 2.0 ではサポートしているモデルに関しては以下のコマンドでダウンロードから変換までを実行することができます。

以下のスクリプトは IPython や Jupyter Notebook 上での実行をサポートしていないことに注意してください。

from nemo.collections import vlm
from nemo.collections.llm import import_ckpt

if __name__ == "__main__":
    import_ckpt(
        model=vlm.MLlamaModel(vlm.MLlamaConfig11BInstruct()),
        source="hf://meta-llama/Llama-3.2-11B-Vision-Instruct",
        output_path="/workspace/models/Llama-3.2-11B-Vision-Instruct",
    )

実行が完了すると以下のようなログが表示され、モデルのダウンロードから変換までが完了していることが確認できます。

✓ Checkpoint imported to /workspace/models/Llama-3.2-11B-Vision-Instruct
Imported Checkpoint
├── context/
│   ├── nemo_tokenizer/
│   │   ├── special_tokens_map.json
│   │   ├── tokenizer.json
│   │   └── tokenizer_config.json
│   ├── io.json
│   └── model.yaml
└── weights/
    ├── .metadata
    ├── __0_0.distcp
    ├── __0_1.distcp
    ├── common.pt
    └── metadata.json

データセットの準備

このチュートリアルでは llm-jp/llava-instruct-ja を使用します。このデータセットには追加で画像データセットのダウンロードが必要になるため、以下のスクリプトで併せてダウンロードします (画像データセットのダウンロードおよび解凍には時間がかかります)。

import os
import zipfile
import requests

from datasets import load_dataset
from tqdm import tqdm

DATA_DIR = "/workspace/dataset/"
DATASET_ID = "llm-jp/llava-instruct-ja"
URL = "http://images.cocodataset.org/zips/train2017.zip"
output_file = os.path.join(DATA_DIR, "train2017.zip")


if __name__ == "__main__":
    os.makedirs(DATA_DIR, exist_ok=True)
    dataset = load_dataset(DATASET_ID)
    dataset["train"].to_pandas().to_json(f"{DATA_DIR}/llava_instruct_ja_156k.json", orient="records", force_ascii=False, indent=2)
    dataset = dataset["train"].train_test_split(test_size=0.05, shuffle=True, seed=42)
    dataset["train"].to_json(f"{DATA_DIR}/training.jsonl", force_ascii=False)
    dataset["test"].to_json(f"{DATA_DIR}/validation.jsonl", force_ascii=False)

    response = requests.get(URL, stream=True)
    if response.status_code == 200:
	 total_size = int(response.headers.get('content-length', 0))
        block_size = 1024  # 1KB

	 with open(output_file, "wb") as file, tqdm(
            desc="Downloading train2017.zip",
            total=total_size,
            unit='B',
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for chunk in response.iter_content(chunk_size=block_size):
                if chunk:
                    file.write(chunk)
                    bar.update(len(chunk))
        print(f"Downloaded {output_file} successfully!")
    else:
        print("Failed to download the file.")

    extract_to = os.path.join(DATA_DIR, "coco/")
    with zipfile.ZipFile(output_file, "r") as zip_ref:
        zip_ref.extractall(extract_to)
        print(f"Extracted files to {extract_to}")

SFT の実行

モデルとデータの準備が整ったので SFT を実行します。最後に完全なスクリプトを提示しますが、ここではスクリプトを少し分解しながら説明します。

まずはカスタム データを扱う準備です。カスタム データを扱う際にはここで使用する preloaded と後述する Megatron-Energon が利用できます。preloaded は現時点 (25.04 コンテナー) では学習データと検証データを分けて入力できません (内部でも分割してくれません)。このことからここではカスタムの preloaded データ モジュールを作ることでこの問題に対処します。その方法は元のデータ モジュールをオーバーライドして学習データと検証データを受け取れるように変更することです。これはあくまで一例であり、開発者の方の目的に合うように自由に変更することができます。

# The following is an example of a custom dataset configuration.
data_config = vlm.ImageDataConfig(
    image_folder="/workspace/dataset",
    conv_template="mllama",  # Customize based on your dataset needs
    )


class CustomMLlamaPreloadedDataModule(vlm.MLlamaPreloadedDataModule):
    def __init__(
        self,
        paths: str | List[str],
        weights: Optional[List[float]] = None,
        data_config: Optional[DataConfig] = ImageDataConfig,
        seq_length: int = 2048,
        decoder_seq_length: Optional[int] = None,
        tokenizer: Optional = None,
        image_processor: Optional = None,
        micro_batch_size: int = 4,
        global_batch_size: int = 8,
        num_train_samples: int = 10_000,
        num_val_samples: int = 10_000,
        num_test_samples: int = 10_000,
        num_workers: int = 8,
        pin_memory: bool = True,
        persistent_workers: bool = False,
        use_packed_sequence: bool = False,
        seed: int = 1234,
    ) -> None:
        super().__init__(paths=paths)
        if not isinstance(paths, (list, tuple)):
            paths = [paths]
        if weights is not None:
            assert len(weights) == len(paths)
            if len(weights) == 1:
                # weights must be None if there is only one dataset
                weights = None

        self.paths = paths
        self.weights = weights
        self.data_config = data_config
        self.seq_length = seq_length
        self.decoder_seq_length = decoder_seq_length
        self.micro_batch_size = micro_batch_size
        self.global_batch_size = global_batch_size
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.num_train_samples = num_train_samples
        self.num_val_samples = num_val_samples
        self.num_test_samples = num_test_samples
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers
        self.seed = seed
        self.use_packed_sequence = use_packed_sequence
        self.init_global_step = 0
        self.tokenizer = tokenizer
        self.image_processor = image_processor

        if tokenizer is None or image_processor is None:
            logging.warning(
                "Processor and tokenizer are not provided! Fall back to `meta-llama/Llama-3.2-11B-Vision-Instruct`."
            )
            from transformers import AutoProcessor

            processor = AutoProcessor.from_pretrained(
                "meta-llama/Llama-3.2-11B-Vision-Instruct"
            )
            self.tokenizer = tokenizer or processor.tokenizer
            self.image_processor = image_processor or processor.image_processor

        self.data_sampler = MegatronDataSampler(
            seq_len=self.seq_length,
            decoder_seq_len=self.decoder_seq_length,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
            dataloader_type="cyclic",
        )

    def setup(self, stage: str = "") -> None:
        # assert len(self.paths) == 1, "not yet support blend dataset in MLlama 2.0!"
        if self.use_packed_sequence:
            pass  # TODO
        else:
            # TODO:
            # rng = torch.Generator().manual_seed(self.seed)
            # train_dataset, val_dataset, test_dataset =
            # random_split(dataset, [train_size, val_size, test_size], generator=rng)
            self._train_ds = vlm.mllama.data.preloaded.MLlamaDataset(
                self.paths[0],
                self.data_config,
                self.tokenizer,
                self.image_processor,
                self.seq_length,
            )
            self._validation_ds = vlm.mllama.data.preloaded.MLlamaDataset(
                self.paths[1],
                self.data_config,
                self.tokenizer,
                self.image_processor,
                self.seq_length,
            )

次に NeMo 1.0 API のコンフィグ (YAMLファイル) にあたる recipe を定義します。各モデルごとに recipe にはデフォルトの値が設定されていますが、開発者の方はこれを Python 上で変更できます。一部の設定は run.Config という NeMo-Run の機能を使用するため、少しとっつきにくいかもしれませんが、その際は NeMo-Run リポジトリにあるこちらのガイドやよくある質問がヒントになるかもしれません。

def configure_recipe(nodes: int = 1, gpus_per_node: int = 8):
    recipe = vlm.mllama_11b.finetune_recipe(
        name="mllama_11b_finetune",
        dir="/workspace/results",
        num_nodes=nodes,
        num_gpus_per_node=gpus_per_node,
        peft_scheme="none",  # 'lora', 'none'
        )

    recipe.resume.restore_config = run.Config(
        nl.RestoreConfig,
        path="/workspace/models/Llama-3.2-11B-Vision-Instruct",
    )

    strategy = run.Config(
        nl.MegatronStrategy,
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=1,
        encoder_pipeline_model_parallel_size=0,
        pipeline_dtype=torch.bfloat16,
        )
    recipe.trainer.strategy=strategy

    recipe.trainer.max_steps = 1160
    recipe.trainer.log_every_n_steps = 1
    recipe.trainer.limit_val_batches = 1.0
    recipe.trainer.val_check_interval = 580
    recipe.log.ckpt.train_time_interval = None
    recipe.optim = distributed_fused_adam_with_cosine_annealing(
        max_lr=2.0e-5,
        min_lr=2.0e-07,
        warmup_steps=100,
        )

    recipe.data = run.Config(
        CustomMLlamaPreloadedDataModule,
        paths=[
            "/workspace/dataset/training.jsonl",
            "/workspace/dataset/validation.jsonl",
        ],  # Path to your llava-like dataset
        data_config=data_config,
        seq_length=6404,
        decoder_seq_length=2048,
        global_batch_size=128,
        micro_batch_size=1,
        tokenizer=None,  # Define your tokenizer if needed
        image_processor=None,  # Add an image processor if required
        num_workers=2,  # Number of workers for data loading
    )

    return recipe

このチュートリアルでは容易さとカスタマイズの自由度のバランスがとれた recipe を使った方法を取り上げていますが、NeMo 2.0 ではより低レベルの API を提供する finetune API を使用する方法やターミナルから 1 ラインで実行できる NeMo-Run CLI を使った方法なども提供しています (開発速度がそれぞれ異なることには注意が必要です)。

図 2. NeMo 2.0 の実行オプション

最後にスクリプトの全体を示します。

import logging
from typing import Any, Dict, List, Optional, Sequence

import torch
import nemo_run as run
from nemo import lightning as nl
from nemo.collections import vlm
from nemo.collections.llm.recipes.optim.adam import (
    distributed_fused_adam_with_cosine_annealing,
)
from nemo.collections.vlm.neva.data.config import DataConfig, ImageDataConfig
from nemo.lightning.pytorch.plugins import MegatronDataSampler


# The following is an example of a custom dataset configuration.
data_config = vlm.ImageDataConfig(
    image_folder="/workspace/dataset",
    conv_template="mllama",  # Customize based on your dataset needs
    )


class CustomMLlamaPreloadedDataModule(vlm.MLlamaPreloadedDataModule):
    def __init__(
        self,
        paths: str | List[str],
        weights: Optional[List[float]] = None,
        data_config: Optional[DataConfig] = ImageDataConfig,
        seq_length: int = 2048,
        decoder_seq_length: Optional[int] = None,
        tokenizer: Optional = None,
        image_processor: Optional = None,
        micro_batch_size: int = 4,
        global_batch_size: int = 8,
        num_train_samples: int = 10_000,
        num_val_samples: int = 10_000,
        num_test_samples: int = 10_000,
        num_workers: int = 8,
        pin_memory: bool = True,
        persistent_workers: bool = False,
        use_packed_sequence: bool = False,
        seed: int = 1234,
    ) -> None:
        super().__init__(paths=paths)
        if not isinstance(paths, (list, tuple)):
            paths = [paths]
        if weights is not None:
            assert len(weights) == len(paths)
            if len(weights) == 1:
                # weights must be None if there is only one dataset
                weights = None

        self.paths = paths
        self.weights = weights
        self.data_config = data_config
        self.seq_length = seq_length
        self.decoder_seq_length = decoder_seq_length
        self.micro_batch_size = micro_batch_size
        self.global_batch_size = global_batch_size
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.num_train_samples = num_train_samples
        self.num_val_samples = num_val_samples
        self.num_test_samples = num_test_samples
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers
        self.seed = seed
        self.use_packed_sequence = use_packed_sequence
        self.init_global_step = 0
        self.tokenizer = tokenizer
        self.image_processor = image_processor

        if tokenizer is None or image_processor is None:
            logging.warning(
                "Processor and tokenizer are not provided! Fall back to `meta-llama/Llama-3.2-11B-Vision-Instruct`."
            )
            from transformers import AutoProcessor

            processor = AutoProcessor.from_pretrained(
                "meta-llama/Llama-3.2-11B-Vision-Instruct"
            )
            self.tokenizer = tokenizer or processor.tokenizer
            self.image_processor = image_processor or processor.image_processor

        self.data_sampler = MegatronDataSampler(
            seq_len=self.seq_length,
            decoder_seq_len=self.decoder_seq_length,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
            dataloader_type="cyclic",
        )

    def setup(self, stage: str = "") -> None:
        # assert len(self.paths) == 1, "not yet support blend dataset in MLlama 2.0!"
        if self.use_packed_sequence:
            pass  # TODO
        else:
            # TODO:
            # rng = torch.Generator().manual_seed(self.seed)
            # train_dataset, val_dataset, test_dataset =
            # random_split(dataset, [train_size, val_size, test_size], generator=rng)
            self._train_ds = vlm.mllama.data.preloaded.MLlamaDataset(
                self.paths[0],
                self.data_config,
                self.tokenizer,
                self.image_processor,
                self.seq_length,
            )
            self._validation_ds = vlm.mllama.data.preloaded.MLlamaDataset(
                self.paths[1],
                self.data_config,
                self.tokenizer,
                self.image_processor,
                self.seq_length,
            )


def configure_recipe(nodes: int = 1, gpus_per_node: int = 8):
    recipe = vlm.mllama_11b.finetune_recipe(
        name="mllama_11b_finetune",
        dir="/workspace/results",
        num_nodes=nodes,
        num_gpus_per_node=gpus_per_node,
        peft_scheme="none",  # 'lora', 'none'
        )

    recipe.resume.restore_config = run.Config(
        nl.RestoreConfig,
        path="/workspace/models/Llama-3.2-11B-Vision-Instruct",
    )

    strategy = run.Config(
        nl.MegatronStrategy,
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=1,
        encoder_pipeline_model_parallel_size=0,
        pipeline_dtype=torch.bfloat16,
        )
    recipe.trainer.strategy=strategy

    recipe.trainer.max_steps = 1160
    recipe.trainer.log_every_n_steps = 1
    recipe.trainer.limit_val_batches = 1.0
    recipe.trainer.val_check_interval = 580
    recipe.log.use_datetime_version = False
    recipe.log.ckpt.train_time_interval = None
    recipe.optim = distributed_fused_adam_with_cosine_annealing(
        max_lr=2.0e-5,
        min_lr=2.0e-07,
        warmup_steps=100,
        )

    recipe.data = run.Config(
        CustomMLlamaPreloadedDataModule,
        paths=[
            "/workspace/dataset/training.jsonl",
            "/workspace/dataset/validation.jsonl",
        ],  # Path to your llava-like dataset
        data_config=data_config,
        seq_length=6404,
        decoder_seq_length=2048,
        global_batch_size=128,
        micro_batch_size=1,
        tokenizer=None,  # Define your tokenizer if needed
        image_processor=None,  # Add an image processor if required
        num_workers=2,  # Number of workers for data loading
    )

    return recipe


if __name__ == "__main__":
    print(configure_recipe())
    run.run(configure_recipe(), direct=True)

ジョブの実行が完了すると results/mllama_11b_finetune/checkpoints にチェックポイントが格納されています(検証環境では数時間かかりました)。

(オプション): Megatron-Energon を使ったアプローチ

Megatron-Energon は特にマルチモーダル モデルにおいて、大規模な分散学習環境で効率的に処理できるように設計されたデータ ローダーです。前述の preloaded は全てのデータをメモリ上にロードするため、大量のデータを扱う際には理想的な方法ではありません。ここでは Megatron-Energon を NeMo の VLM カスタマイズに使用する方法を紹介します。

まずは以下のスクリプトでデータを Megatron-Energon がサポートしている WebDataset 形式に変換します。

import json
import os
import random

import webdataset as wds
from tqdm import tqdm

# Set the path to the LLaVA-Pretrain dataset directory
dataset_dir = "/workspace/dataset"

# Paths to the dataset files
json_file = os.path.join(dataset_dir, "llava_instruct_ja_156k.json")
output_path = os.path.join(dataset_dir, "wds")

if not os.path.exists(output_path):
    os.mkdir(output_path)

# Load data
with open(json_file, "r") as f:
    data = json.load(f)

random.shuffle(data)

# Convert JSON to WebDataset
with wds.ShardWriter(
    os.path.join(output_path, "instruction-%d.tar"), maxcount=10000
) as shard_writer:
    for entry in tqdm(data):
        with open(os.path.join(dataset_dir, entry["image"]), "rb") as img_file:
            image_data = img_file.read()
        sample = {
            "__key__": entry["id"],
            "jpg": image_data,
            "json": json.dumps(entry["conversations"], ensure_ascii=False).encode("utf-8"),
        }
        shard_writer.write(sample)

print("Dataset successfully converted to WebDataset format.")

次にメタデータを作成します。メタデータの作成はコマンドが用意されており、以下のようにインタラクティブに実行していきます。

> energon prepare /workspace/dataset/wds

Found 16 tar files in total. The first and last ones are:
- instruction-0.tar
- instruction-9.tar
If you want to exclude some of them, cancel with ctrl+c and specify an exclude filter in the command line.
Please enter a desired train/val/test split like "0.5, 0.2, 0.3" or "8,1,1":0.95,0.05,0 

Indexing shards  [####################################]  16/16          
Sample 0, keys:
 - jpg
 - json
Json content of sample 0 of instruction-0.tar:
[
  {
    "from": "human",
    "value": "<image>\n\u753b\u50cf\u306b\u306f\u4f55\u304c\u5199\u3063\u3066\u3044\u307e\u3059\u304b\uff1f"
  },
  {
    "from": "gpt",
    "value": "\u753b\u50cf\u306b\u306f\u3055\u307e\u3056\u307e\u306a\u7a2e\u985e\u306e\u98df\u6750\u304c\u76db\u3089\u308c\u305f\u660e\u308b\u3044\u8272\u306e\u30d7..."
  },
  {
    "from": "human",
    "value": "\u3053\u306e\u753b\u50cf\u306b\u6620\u3063\u3066\u3044\u308b\u98df\u6750\u306f\u4f55\u7a2e\u985e\u3042\u308a\u307e\u3059\u304b\uff1f"
  },
  "..."
]
Sample 1, keys:
 - jpg
 - json
Json content of sample 1 of instruction-0.tar:
[
  {
    "from": "human",
    "value": "\u3053\u306e\u753b\u50cf\u306b\u306f\u4f55\u982d\u306e\u30ad\u30ea\u30f3\u304c\u5199\u3063\u3066\u3044\u307e\u3059\u304b\uff1f\n<im..."
  },
  {
    "from": "gpt",
    "value": "\u753b\u50cf\u306b\u306f\u5c11\u306a\u304f\u3068\u30822\u982d\u306e\u30ad\u30ea\u30f3\u304c\u5199\u3063\u3066\u3044\u307e\u3059\u30021\u982d..."
  },
  {
    "from": "human",
    "value": "\u30ad\u30ea\u30f3\u306f\u4f55\u3092\u98df\u3079\u3066\u3044\u307e\u3059\u304b\uff1f"
  },
  "..."
]
Found the following part types in the dataset: json, jpg
Do you want to create a dataset.yaml interactively? [Y/n]:Y 

The following sample types are available:
0. CaptioningSample
1. ImageClassificationSample
2. ImageSample
3. InterleavedSample
4. MultiChoiceVQASample
5. OCRSample
6. Sample
7. SimilarityInterleavedSample
8. TextSample
9. VQASample
10. VidQASample
11. Crude sample (plain dict for cooking)
Please enter a number to choose a class: 9

The sample type you selected:

@dataclass
class VQASample(Sample):
    """Sample type for visual question answering."""

    #: The input image tensor in the shape (C, H, W)
    image: torch.Tensor
    #: The context/question for the image
    context: str

    #: The possible answers. Not set for testing.
    answers: Optional[List[str]] = None
    #: The weights of the possible answers. Optionally available.
    answer_weights: Optional[torch.Tensor] = None

Do you want to set a simple field_map[Y] (or write your own sample_loader [n])? [Y/n]: n

Created /workspace/dataset/wds/.nv-meta/sample_loader.py. Please edit it to return the proper values.
Done

生成された .nv-meta/sample_loader.py を以下のように書き換えます。データセットによってローダーの処理は異なる可能性があることに注意してください。

# This file was automatically generated by `energon prepare`.
# TODO: Edit it to return the proper fields
# import torch

def sample_loader(raw: dict) -> dict:    # Note: Images are already decoded to tensors
    # TODO: Set the correct values for all (required) fields
    context = [entry['value'] for entry in raw["json"] if entry['from'] == 'human']
    answers = [entry['value'] for entry in raw["json"] if entry['from'] == 'gpt']
    return dict(
        image=raw["jpg"],  # expected type: torch.Tensor
        context=context,  # expected type: str
        answers=answers,  # expected type: typing.Optional[typing.List[str]], default: None
        answer_weights=None,  # expected type: typing.Optional[torch.Tensor], default: None
    )

def part_filter(part: str) -> bool:
    # TODO: Filter for parts required by the sample_loader
    # E.g. if your dataset contains jpeg, txt and json, but you won't use json,
    # remove it from the list, such that it is not decoded. If you need all, keep as is
    return part in ('jpg', 'json')

学習は以下のスクリプトで実行できます。recipe.data の部分が主な変更点になります。加えてこのスクリプトには image_processor の設定を修正する PR#13618 が反映されていないため、既存のクラスをオーバーライドして対応しています。

import logging
from typing import Any, Dict, List, Optional, Sequence

import nemo_run as run
import torch
from megatron.energon import VQASample
from nemo import lightning as nl
from nemo.collections import vlm
from nemo.collections.llm.recipes.optim.adam import (
    distributed_fused_adam_with_cosine_annealing,
)
from nemo.collections.vlm.neva.data.config import DataConfig, ImageDataConfig
from nemo.lightning.pytorch.plugins import MegatronDataSampler
from nemo.collections.multimodal.data.energon.config import MultiModalSampleConfig
from nemo.collections.vlm.mllama.data.task_encoder import LlamaTaskEncoder
from nemo.collections.multimodal.data.energon import EnergonMultiModalDataModule
from nemo.collections.multimodal.data.energon.conversation import MLlamaTemplateConfig
from nemo.collections.vlm.mllama.data.sample_encoder import Llama3SampleEncoder
from transformers import AutoProcessor
from dataclasses import dataclass, field


# Load the processor
processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")

# Paths and configuration
data_path = "/workspace/dataset/wds"
image_processor = processor.image_processor
tokenizer = processor.tokenizer

# Define multimodal sample configuration
multimodal_sample_config = MultiModalSampleConfig(conversation_template_config=MLlamaTemplateConfig(stop_string="<|eot_id|>"))


class CustomLlama3SampleEncoder(Llama3SampleEncoder):
    """MLlama Sample Encoder"""

    def __init__(
        self,
        tokenizer,
        image_processor,
        multimodal_sample_config=MultiModalSampleConfig(),
    ):
        """
        Initialize the VQASampleEncoder.

        Parameters:
        tokenizer (Tokenizer): The HF tokenizer used for processing text.
        image_processor (ImageProcessor): The HF image processor used for preprocessing images.
        multimodal_sample_config (MultiModalSampleConfig, optional): Configuration object for multimodal samples.
            Defaults to MultiModalSampleConfig().
        """
        super().__init__(tokenizer, image_processor, multimodal_sample_config)
        self.conversation_template_config = (
            multimodal_sample_config.conversation_template_config
        )

    def process_image(self, image) -> Dict[str, torch.Tensor]:
        image_dict = self.image_processor.preprocess(image, return_tensors="pt")
        return image_dict


# Initialize the mllama task encoder
task_encoder = LlamaTaskEncoder(
    tokenizer=tokenizer,
    image_processor=image_processor,
    multimodal_sample_config=multimodal_sample_config,
)


task_encoder.encoders = {
    VQASample.__name__: CustomLlama3SampleEncoder(
        tokenizer, image_processor, multimodal_sample_config
    )
}


def configure_recipe(nodes: int = 1, gpus_per_node: int = 8):
    recipe = vlm.mllama_11b.finetune_recipe(
        name="mllama_11b_finetune_energon",
        dir="/workspace/results",
        num_nodes=nodes,
        num_gpus_per_node=gpus_per_node,
        peft_scheme="none",  # 'lora', 'none'
        )

    recipe.resume.restore_config = run.Config(
        nl.RestoreConfig,
        path="/workspace/models/Llama-3.2-11B-Vision-Instruct",
        )

    recipe.trainer.max_steps = 1143
    recipe.trainer.log_every_n_steps = 1
    recipe.trainer.limit_val_batches = 1.0
    recipe.trainer.val_check_interval = 1143
    recipe.log.use_datetime_version = False
    recipe.log.ckpt.train_time_interval = None
    recipe.optim = distributed_fused_adam_with_cosine_annealing(
        max_lr=2.0e-5,
        min_lr=2.0e-07,
        warmup_steps=100,
    )
    # Create the data module
    recipe.data = EnergonMultiModalDataModule(
        path=data_path,
        tokenizer=tokenizer,
        image_processor=image_processor,
        seq_length=6404,
        decoder_seq_length=2048,
        num_workers=8,
        micro_batch_size=1,
        global_batch_size=128,
        multimodal_sample_config=multimodal_sample_config,
        task_encoder=task_encoder,
    )

    return recipe


if __name__ == "__main__":
    print(configure_recipe())
    run.run(configure_recipe(), direct=True)

(参考): NeMo AutoModel

NeMo 2.0 で LLM や VLM のファインチューニングが実行可能なモデルには大きく 2 つのタイプがあります。1 つは Megatron 形式のモデルでこれは本チュートリアルで使用したものになります。Megatron 形式は幅広い並列化手法がサポートされており、高い計算パフォーマンスが期待できます。Megatron 形式でサポートされている VLM はこちらを参照してください。もう 1 つが NeMo AutoModel です。このチュートリアルでは扱いませんが、NeMo AutoModel は HuggingFace のモデルとシームレスに統合されており、開発者は AutoModel を活用することで HuggingFace の幅広いモデルを NeMo でカスタマイズすることが可能になります。Hugging Face の VLM モデルのほとんどは NeMo AutoModel と互換性がありますが、こちらにあるモデル、データセットの組み合わせは収束を確認済みです。

図 3. NeMo Framework のトレーニング ワークフロー

NeMo フォーマットから HuggigFace フォーマットへの変換

NeMo でカスタマイズしたモデルは HuggingFace フォーマットへ変換することが可能です。25.04 コンテナーには、Llama-3.2-Vision の HF へのエクスポート機能が追加されていないため、PR#13346 を反映させます。

cd /opt/NeMo
curl -L https://github.com/NVIDIA/NeMo/pull/13346.diff | git apply
cd /workspace

次に以下のスクリプトを実行します。path は手元のディレクトリの path に変更してください。

from pathlib import Path
from nemo.collections.llm import export_ckpt

if __name__ == "__main__":
    export_ckpt(
        path=Path(
            "/workspace/results/mllama_11b_finetune/checkpoints/model_name=0--val_loss=1.00-step=1159-consumed_samples=148480.0"
            ),
        target="hf",
        output_path=Path("/workspace/results/Llama-3.2-11B-Vision-SFT"),
        overwrite=True,
    )

✓ Checkpoint exported to /workspace/results/Llama-3.2-11B-Vision-SFT とログが表示されると変換は完了です。export されたファイルの中には、HuggingFace でのモデル読み込みに必要なchat_template.jsonpreprocessor_config.json は含まれていないため、別途、meta-llama/Llama-3.2-11B-Vision-Instruct からダウンロードする必要があります。

VLM の推論

推論には vLLM を使用しました。標準コマンドで動作すると meta-llama/Llama-3.2-11B-Vision-Instruct  は コンテキスト長が 128k と長いため DGX H100 や DGX A100 で GPU を 8 枚使用しても  Memory out of error が発生します。max-num-seqs がデフォルトで 256 になっているのでこちらを下げることで対応します。今回は 16 に設定しました。

コンテキスト長を調整する `max-model-len` を調整することで使用メモリを抑えられますが、画像埋め込みトークンを処理しきれないケースが出てくるため、こちらはデフォルト値を採用しています。

下記のようなコマンドで実行しました。vllm のバージョンは 0.8.4 を使用しました。

vllm serve "モデル path" --tensor-parallel-size {並列数} --max-num-seqs 16

VLM の評価

llm-jp-eval-mm (2025 年 5 月 11 日時点のマスター ブランチのコード) を評価に使用させていただきました。多様なタスクが評価でき、代表的なモデルで評価されているのでこちらを使用させていただきました。

llm-jp-eval-mm はブログ執筆時点では vllm に対応していませんでした。vllm は OpenAI の API に準拠しているので llm-jp-eval-mm にすでに実装されている gpt4o.py を元に簡単な修正をして動作させました。

推論で設定したパラメーターは下記です。llm-jp-eval-mm で設定されている標準パラメーターを使用しました。

  • max_new_tokens: 1024
  • temperature: 0.0
  • top_p: 1.0
  • num_beams: 1
  • do_sample: False
  • use_cache: True

評価に使用したタスクは下記になります。

日本語のタスク

英語のタスク

評価結果は下記のようになります。Judge モデルは gpt-4o-2024-11-2 を使用しました。本来は複数回測定して平均と標準偏差を記載すべきですが、時間の都合上、割愛させて頂きました。大きな性能差が見られないケースは複数回計測していないため、計測誤差の範囲内の可能性があります。

ModelLLAVA/LLMLLAVA/RougeJVB-ItW/LLMJVB-ItW/RougeJDocQA/AccJDocQA/LLMHeron/LLMMECHA/AccJMMMU/AccMMMU/AccJIC/AccVG-VQA/LLMVG-VQA/Rouge
Llama-3.2-11B-Vision-Instruct3.6228.313.4430.3114.812.4441.249.23334.678.93.313.7
Llama-3.2-11B-Vision-Instruct-SFT3.734.93.638.217.62.557.218.517.9533.8973.23.714.6
Llama-3.2-11B-Vision-Instruct-SFT-Energon3.7323.643.616.92.560.217.8818.736.972.883.814.9
表 1. llm-jp-eval-mm リーダーボードの値

llm-jp/llava-instruct-ja データは COCO 日常写真 × GPT – 生成説明 を使用しているため Japanese Heron Bench や JA-VLM-Bench-In-the-Wild のテキスト生成系のタスクでは性能を上げています。逆に MECHA/JMMMU のような選択式のタスクのスコアが下がっている可能性があります。

英語の MMMU では性能が落ちていないので、日本語のデータセットによる学習によって英語の読解能力は落ちていないように見えます。

特徴的なケース

Japanese Heron Bench では精度向上しました。下記のようなケースで画像の内容を把握しないと答えられない問題に適切に答えてます。SFT する前では質問に答えられず余分な情報を追加しています。

図 3.1. Japanese Heron Bench の評価サンプル画像 (左から index、画像、質問、Reference の回答、SFT 前のモデルの回答、評価スコア、SFT 後のモデルの回答、評価スコア)

JA-VLM-Bench-In-the-Wild でも精度向上しました。SFT をしていないケースでは同じような文字列を繰り返し正しい回答をしていないことが見受けられます。

図 3.2. JA-VLM-Bench-In-the-Wild の評価サンプル画像 (左から index、画像、質問、Reference の回答、SFT 前のモデルの回答、GPT4o の評価スコア、Rouge スコア、SFT 後のモデルの回答、GPT4o の評価スコア、Rouge スコア)

MECHA-ja の精度は向上しませんでした。これは学習に使用した llm-jp/llava-instruct-ja のデータセットが質問と答えの文章のセットであり、答えに選択肢形式が無いので SFT した際に破壊的忘却が発生した可能性があります。

下記が回答の一例です。回答は合っていますが質問の指示を無視しているため、正解とは判定されていません。

図 3.3. MECHA-ja の評価サンプル画像 (左から index、画像、質問、Reference の回答、SFT 前のモデルの回答、評価スコア、SFT 後のモデルの回答、評価スコア)

まとめ

本記事では、NeMo 2.0 を使用した VLM のファインチューニングから推論、評価までの流れを紹介しました。NeMo Framework を使用して、日本語やドメインに特化した VLM の開発が加速すると嬉しいです。


関連情報

Tags