# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Configuration for fine-tuning Gemma 4 31B (dense) on the MedPix dataset for image description
# Requires 8 GPUs (FSDP2 with activation checkpointing)
# torchrun --nproc-per-node=8 examples/vlm_finetune/finetune.py -c examples/vlm_finetune/gemma4/gemma4_31b.yaml

recipe: FinetuneRecipeForVLM
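
# Batch math (a sketch, assuming all 8 GPUs form the data-parallel group):
# global_batch_size 8 = local_batch_size 1 x 8 ranks, so no gradient accumulation is needed.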
step_scheduler:
  global_batch_size: 8
  local_batch_size: 1
  ckpt_every_steps: 500
  val_every_steps: 500
  num_epochs: 2

dist_env:
  backend: nccl
  timeout_minutes: 60

rng:
  _target_: nemo_automodel.components.training.rng.StatefulRNG
  seed: 42
  ranked: true
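
# Rough memory note (a back-of-envelope estimate, not stated in this recipe):
# 31B parameters in bf16 are ~62 GB of weights alone, so FSDP2 sharding across
# the 8 GPUs plus activation checkpointing (below) is what makes this fit.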
model:
  _target_: nemo_automodel.NeMoAutoModelForImageTextToText.from_pretrained
  pretrained_model_name_or_path: google/gemma-4-31B-it
  torch_dtype: torch.bfloat16
  use_liger_kernel: true
  use_sdpa_patching: false
  attn_implementation: eager
  # 31B does not use kv_shared layers (those are only used in the 2B and 4B variants), hence use_cache: false.
  text_config:
    use_cache: false

processor:
  padding_side: right

checkpoint:
  enabled: true
  checkpoint_dir: vlm_checkpoints/gemma4_31b_it/
  model_save_format: torch_save
  save_consolidated: false
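
# Assumption about the sharding layout (not stated in this file): with dp_size
# left as none, the recipe typically derives the data-parallel size from the
# world size, here 8 / (tp_size 1 * cp_size 1) = 8 FSDP2 shards.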
distributed:
  strategy: fsdp2
  dp_size: none
  tp_size: 1
  cp_size: 1
  sequence_parallel: false
  # Activation checkpointing is required for Gemma 4 31B to fit in memory; FSDP alone leads to OOM.
  activation_checkpointing: true

loss_fn:
  _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy
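
# The small train[:1000] / validation[:500] slices below keep the run short
# (consistent with the 20-minute CI budget at the bottom); for a full
# fine-tune you would presumably use the complete splits.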
dataset:
  _target_: nemo_automodel.components.datasets.vlm.datasets.make_medpix_dataset
  path_or_dataset: mmoukouba/MedPix-VQA
  split: train[:1000]

dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  num_workers: 0
  pin_memory: true
  collate_fn:
    _target_: nemo_automodel.components.datasets.vlm.collate_fns.gemma4_prefix_collate_fn

validation_dataset:
  _target_: nemo_automodel.components.datasets.vlm.datasets.make_medpix_dataset
  path_or_dataset: mmoukouba/MedPix-VQA
  split: validation[:500]

validation_dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  collate_fn:
    _target_: nemo_automodel.components.datasets.vlm.collate_fns.gemma4_prefix_collate_fn

optimizer:
  _target_: torch.optim.AdamW
  lr: 1e-5
  weight_decay: 0.01
  betas: [0.9, 0.95]

lr_scheduler:
  lr_decay_style: cosine
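
# Only the language model is trained: embeddings and the vision/audio towers
# are frozen, which shrinks the trainable parameter count and the AdamW
# optimizer state accordingly.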
freeze_config:
  freeze_embeddings: true
  freeze_vision_tower: true
  freeze_audio_tower: true
  freeze_language_model: false

# wandb:
#   project: <your-project>
#   entity: <your-entity>
#   name: <your-run-name>

ci:
  recipe_owner: athitten
  time: "00:20:00"