Fine-tuning Hyperparameters

#27
by tanliboy - opened

What are the optimal hyperparameters for fine-tuning gemma-2-9b?

The evaluation results of gemma-2-9b-it are very promising. However, when I fine-tune the base model on the HuggingFaceH4/ultrachat_200k dataset using learning rates between 2.0e-5 and 1.0e-4 and a batch size of 16, the resulting SFT model does not perform as well as models fine-tuned from Llama3-8b and Qwen2-7b. Given that the gemma-2 architecture is different (it uses logit soft-capping), what are the recommended hyperparameters for fine-tuning? Is there a fine-tuning recipe I can start from? Any suggestions? Thanks!
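For readers unfamiliar with the soft-capping mentioned above, here is a minimal sketch of the operation: a logit x is squashed as cap * tanh(x / cap), so small values pass through almost unchanged while large values saturate just below ±cap. The two cap values are assumed from my recollection of the Gemma2Config defaults in transformers (attn_logit_softcapping and final_logit_softcapping); please verify them against the model's config.json.

```python
import math


def soft_cap(logit: float, cap: float) -> float:
    """Smoothly squash a logit into the range [-cap, cap] via cap * tanh(logit / cap)."""
    return cap * math.tanh(logit / cap)


# Assumed Gemma2Config defaults (verify in the model's config.json):
ATTN_LOGIT_SOFTCAP = 50.0   # applied to attention scores before softmax
FINAL_LOGIT_SOFTCAP = 30.0  # applied to the LM-head output logits


if __name__ == "__main__":
    # Small logits are nearly unchanged; large ones saturate near the cap.
    for x in (1.0, 30.0, 100.0, 1000.0):
        print(x, round(soft_cap(x, FINAL_LOGIT_SOFTCAP), 3))
```

Because the cap bounds the scores, any attention kernel used for Gemma-2 has to apply the same transformation to match the reference model.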

tanliboy changed discussion status to closed

Did you find any solution?

I found that the problem was caused by an incorrect implementation of GQA in transformers, which was fixed 1-2 weeks ago.
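As background on GQA (grouped-query attention): several query heads share one key/value head, so the KV heads must be mapped (or repeated) onto the query heads correctly. The sketch below shows that mapping on plain lists, similar in spirit to the `repeat_kv` helper in transformers' modeling code. It is illustrative only; the thread does not describe the exact bug that was fixed, and the helper names here are hypothetical.

```python
def kv_head_for_query_head(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Return the index of the key/value head used by query head `q_head` under GQA."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    group_size = num_q_heads // num_kv_heads  # query heads per shared KV head
    return q_head // group_size


def repeat_kv(kv_heads: list, num_q_heads: int) -> list:
    """Expand the per-KV-head entries so there is one entry per query head.

    Consecutive query heads share a KV head, e.g. [k0, k1] -> [k0, k0, k1, k1].
    """
    num_kv_heads = len(kv_heads)
    return [
        kv_heads[kv_head_for_query_head(q, num_q_heads, num_kv_heads)]
        for q in range(num_q_heads)
    ]


if __name__ == "__main__":
    print(repeat_kv(["k0", "k1"], num_q_heads=4))
```

If query heads are paired with the wrong KV heads, the model still runs but attends with mismatched keys and values, which would show up as degraded quality rather than an error.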
You should be able to fine-tune a decent model using alignment-handbook recipes. Here is an example run if you are interested (https://huggingface.co./tanliboy/zephyr-gemma-2-9b-dpo).

Ah thx!

Hello, which transformers version included this fix?

@lucasjin It was fixed a while ago.

tanliboy changed discussion status to open

@lucasjin, with the latest released version you need to use the eager attention implementation instead of flash_attention_2. Alternatively, you can install transformers from the main branch. For more details, refer to this issue.
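A minimal sketch of selecting the attention implementation when loading the model, assuming a transformers version that accepts the `attn_implementation` argument in `from_pretrained`. The model id `google/gemma-2-9b` and the helper names are assumptions for illustration; the import is deferred so the kwargs helper can be used without transformers installed.

```python
def gemma2_load_kwargs(use_flash_attention: bool = False) -> dict:
    """Build from_pretrained kwargs; default to eager attention as advised above."""
    impl = "flash_attention_2" if use_flash_attention else "eager"
    return {"attn_implementation": impl}


def load_gemma2(model_id: str = "google/gemma-2-9b", use_flash_attention: bool = False):
    """Load the model with the chosen attention implementation (hypothetical helper)."""
    # Imported lazily so gemma2_load_kwargs() works without transformers installed.
    from transformers import AutoModelForCausalLM

    return AutoModelForCausalLM.from_pretrained(
        model_id, **gemma2_load_kwargs(use_flash_attention)
    )
```

Keeping `use_flash_attention=False` as the default matches the advice for the released version; switch it on only if you have confirmed that your installed version (e.g. one built from main) handles Gemma-2 correctly with flash_attention_2.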
