This repository was archived by the owner on Mar 3, 2024. It is now read-only.

CyberZHG/keras-radam

Keras RAdam

Unofficial implementation of RAdam (Rectified Adam) in Keras.

Install

pip install keras-rectified-adam

External Link

Usage

from tensorflow import keras
import numpy as np
from keras_radam import RAdam

# Build a toy model with the RAdam optimizer
model = keras.models.Sequential()
model.add(keras.layers.Dense(input_shape=(17,), units=3))
model.compile(RAdam(), loss='mse')

# Generate toy data
x = np.random.standard_normal((4096 * 30, 17))
w = np.random.standard_normal((17, 3))
y = np.dot(x, w)

# Fit
model.fit(x, y, epochs=5)
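
For background on what the optimizer above does differently from plain Adam, here is a minimal sketch of the variance rectification term as described in the RAdam paper. It is written in plain Python and is not taken from this package's source; the function name `radam_rectification` is hypothetical, and the threshold of 4 follows the paper (some implementations use a different cutoff).

```python
import math


def radam_rectification(step, beta_2=0.999):
    """Return the rectification factor r_t, or None while the variance
    of the adaptive learning rate is treated as intractable.

    Hypothetical helper for illustration; not part of keras_radam.
    """
    # Maximum length of the approximated simple moving average.
    rho_inf = 2.0 / (1.0 - beta_2) - 1.0
    beta_2_t = beta_2 ** step
    # Length of the approximated SMA at this step (step counts from 1).
    rho_t = rho_inf - 2.0 * step * beta_2_t / (1.0 - beta_2_t)
    if rho_t <= 4.0:
        # Early steps: fall back to an un-adapted (momentum-only) update.
        return None
    return math.sqrt(
        (rho_t - 4.0) * (rho_t - 2.0) * rho_inf
        / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t)
    )
```

In this sketch the factor is undefined for the first few steps, stays well below 1 early in training, and approaches 1 as the step count grows, which is why RAdam behaves like Adam late in training.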

Use Warmup

from keras_radam import RAdam

RAdam(total_steps=10000, warmup_proportion=0.1, min_lr=1e-5)
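
To make the three warmup arguments concrete, the sketch below shows one plausible reading of them: the learning rate rises linearly during the first `warmup_proportion * total_steps` steps, then decays linearly to `min_lr` by `total_steps`. This schedule shape is an assumption for illustration and is not confirmed by the text above; the helper `warmup_lr` and the `base_lr` default of `1e-3` are hypothetical.

```python
def warmup_lr(step, base_lr=1e-3, total_steps=10000,
              warmup_proportion=0.1, min_lr=1e-5):
    """Learning rate at `step` under an assumed linear warmup followed
    by linear decay. Hypothetical helper; not part of keras_radam."""
    warmup_steps = total_steps * warmup_proportion
    if warmup_steps > 0 and step <= warmup_steps:
        # Linear ramp from 0 up to base_lr.
        return base_lr * step / warmup_steps
    # Linear decay from base_lr down to min_lr, then hold at min_lr.
    decay_steps = max(total_steps - warmup_steps, 1)
    decay_rate = (min_lr - base_lr) / decay_steps
    return base_lr + decay_rate * min(step - warmup_steps, decay_steps)


# Inspect the assumed schedule at a few points.
for s in (500, 1000, 5500, 10000):
    print(s, warmup_lr(s))
```

Under these assumptions, the call above with `total_steps=10000` and `warmup_proportion=0.1` would warm up over the first 1000 steps and reach `min_lr` at step 10000.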