Create early_stop.py
This commit is contained in:
parent
fe2d947179
commit
08b3851d3c
26
2024.06.04_early_stop/early_stop.py
Normal file
26
2024.06.04_early_stop/early_stop.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
"""
|
||||||
|
This code is supported by the website: https://www.guanjihuan.com
|
||||||
|
The newest version of this code is on the web page: https://www.guanjihuan.com/archives/41201
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_break_signal_from_loss_array(loss_array, patience=10, min_delta=0.001):
|
||||||
|
break_signal = 0
|
||||||
|
counter = 0
|
||||||
|
num = len(loss_array)
|
||||||
|
for i0 in range(num):
|
||||||
|
if i0 != 0:
|
||||||
|
if abs(loss_array[i0]-loss_array[i0-1])<min_delta:
|
||||||
|
counter += 1
|
||||||
|
if counter >= patience: # 当损失函数的变化绝对值小于 min_delta 的次数超过 patience 次后,给一个停止信号
|
||||||
|
break_signal = 1
|
||||||
|
print(counter) # 查看满足条件的次数
|
||||||
|
return break_signal
|
||||||
|
|
||||||
|
train_times = 100
|
||||||
|
for i0 in range(train_times):
|
||||||
|
print('Training...')
|
||||||
|
loss_array = [10, 3, 1, 0.1, 0.02, 0.003, 0.001, 0.0004, 0.0005, 0.0001, 0.0003]
|
||||||
|
break_signal = get_break_signal_from_loss_array(loss_array, patience=4, min_delta=0.001)
|
||||||
|
if break_signal == 1:
|
||||||
|
break
|
||||||
|
print('Early stop:', break_signal)
|
Loading…
x
Reference in New Issue
Block a user