Dynamiczne dostrajanie (DFT) pomostem do generalizacji w uczeniu nadzorowanym LLM
Dostrajanie nadzorowane (SFT, ang. Supervised Fine-Tuning) to standardowa technika adaptacji modeli językowych (LLM) do nowych zadań, poprzez trenowanie ich na zbiorach danych demonstracyjnych. Metoda ta jest ceniona za prostotę i efektywność w rozwijaniu eksperckiej wiedzy w oparciu o zadane przykłady. Jednocześnie SFT często ustępuje jednak uczeniu ze wzmocnieniem (RL) pod względem generalizacji wiedzy. RL pozwala modelom poznawać różnorodne strategie, co prowadzi do lepszej generalizacji, ale wymaga znacznych zasobów obliczeniowych, starannego dostrajania hiperparametrów i dostępu do sygnałów nagrody, co nie zawsze jest praktyczne.
Czy można zatem ulepszyć SFT? To kluczowe pytanie, zwłaszcza gdy zbiory danych nie zawierają negatywnych przykładów lub modele nagród są niedostępne. Dotychczasowe próby zmierzenia się z wyzwaniami SFT i RL doprowadziły do powstania różnorodnych metod hybrydowych. Częstą strategią jest łączenie początkowej fazy SFT z późniejszym udoskonaleniem RL, jak w metodach typu InstructGPT. Alternatywne podejścia, takie jak przeplatanie kroków SFT i RL lub Direct Preference Optimization (DPO), mają na celu bardziej wydajną integrację sygnałów imitacji i wzmocnienia. Techniki takie jak Negative-aware Fine-Tuning (NFT) pozwalają modelom na samodoskonalenie poprzez modelowanie nieprawidłowych wyjść.
Dynamiczne Dostrajanie (DFT)
Zespół naukowców z Southeast University, UC Berkeley, Shanghai Jiao Tong University, Nanyang Technological University i Wuhan University zaproponował Dynamic Fine-Tuning (DFT), metodę mającą na celu rozwiązanie problemu ograniczonej generalizacji modeli SFT LLM. Poprzez analizę matematyczną zidentyfikowali oni, że standardowe gradienty SFT kodują wadliwą strukturę nagród, ograniczając zdolność modelu do skutecznej generalizacji. DFT rozwiązuje ten problem poprzez stabilizację aktualizacji gradientów poprzez dynamiczne przeskalowanie funkcji celu w oparciu o prawdopodobieństwo każdego tokenu. Ta modyfikacja poprawia generalizację w wielu testach porównawczych i modelach bazowych. Co więcej, DFT wykazuje konkurencyjne wyniki w ustawieniach RL offline, oferując prostszą alternatywę dla tradycyjnych metod RL.
DFT został oceniony w standardowym ustawieniu SFT, gdzie dostępne są tylko dane demonstracyjne, bez próbek negatywnych, modeli nagród lub sygnałów weryfikacyjnych. Był trenowany przy użyciu zbioru danych NuminaMath CoT, który zawiera 860 tys. problemów matematycznych i rozwiązań. Zbiór danych obejmuje różne źródła, w tym chińskie zadania matematyczne ze szkół średnich oraz amerykańskie i międzynarodowe olimpiady matematyczne. W ustawieniu RL offline DFT został przetestowany przy użyciu struktury dostrajania próbkowania odrzucania (RFT). W tym przypadku generowane są odpowiedzi na 10 tys. pytań matematycznych, a poprawne odpowiedzi są weryfikowane i zachowywane, co daje 140 tys. przykładów treningowych. Tworzone są również pary preferencji pozytywno-negatywnych dla treningu DPO z wygenerowanej odpowiedzi.
Wyniki eksperymentów
W ustawieniach SFT, DFT przewyższa standardowe SFT we wszystkich ocenianych LLM i wykazuje lepszą generalizację i niezawodność w trudnych testach porównawczych, gdzie standardowe SFT daje minimalny lub negatywny wpływ. Wykazuje lepszą wydajność uczenia się i szybszą charakterystykę zbieżności oraz przewyższa ważone ważnością SFT (iw-SFT) w większości scenariuszy. W ustawieniach RL offline DFT przewyższa zarówno offline, jak i online RL. Uzyskuje średnio 35,43 punktu, przewyższając najlepszą metodę offline, RFT, o +11,46 punktu i przewyższa najsilniejszy algorytm RL online, GRPO, o +3,43 punktu. Co więcej, DFT uzyskuje 64,71 punktu na Math500, nieznacznie wyprzedzając GRPO, i osiąga znaczne zyski w trudniejszych zadaniach, takich jak AMC23 (+7,19 w porównaniu z GRPO) i Minerva Math (+6,23 w porównaniu z GRPO).
Ograniczenia i przyszłe kierunki badań
W tej pracy naukowcy zajmują się luką generalizacji między SFT i RL. Wprowadzają Dynamic Fine-Tuning (DFT), prostą, ale potężną metodę, która dynamicznie zmienia wagę straty SFT przy użyciu prawdopodobieństw tokenów. Ta jednowierszowa modyfikacja stabilizuje naukę i poprawia generalizację, o czym świadczą zyski wydajności w testach porównawczych rozumowania matematycznego. Jednak oceny DFT są ograniczone do zbiorów danych i modeli skoncentrowanych na matematyce do parametrów 7B, bez testowania w innych domenach lub większych modelach. Ponadto badania te ograniczają się do scenariuszy tekstowych. Przyszłe prace mają na celu rozszerzenie DFT na szersze testy porównawcze, większe modele i zadania wizualno-językowe, aby zweryfikować jego skuteczność między modalnościami.
