@@ -768,7 +768,10 @@ def enforce_terms(terms_accepted, *args):
768768 message = "You must agree to the Terms of Use to proceed."
769769 gr .Info (message )
770770 return message
771- return run_train_script (* args )
771+ try :
772+ return run_train_script (* args )
773+ except Exception as e :
774+ return e
772775
773776 terms_checkbox = gr .Checkbox (
774777 label = i18n ("I agree to the terms of use" ),
@@ -788,32 +791,6 @@ def enforce_terms(terms_accepted, *args):
788791
789792 with gr .Row ():
790793 train_button = gr .Button (i18n ("Start Training" ))
791- train_button .click (
792- fn = enforce_terms ,
793- inputs = [
794- terms_checkbox ,
795- model_name ,
796- save_every_epoch ,
797- save_only_latest ,
798- save_every_weights ,
799- total_epoch ,
800- sampling_rate ,
801- batch_size ,
802- gpu ,
803- overtraining_detector ,
804- overtraining_threshold ,
805- pretrained ,
806- cleanup ,
807- index_algorithm ,
808- cache_dataset_in_gpu ,
809- custom_pretrained ,
810- g_pretrained_path ,
811- d_pretrained_path ,
812- vocoder ,
813- checkpointing ,
814- ],
815- outputs = [train_output_info ],
816- )
817794
818795 stop_train_button = gr .Button (i18n ("Stop Training" ), visible = False )
819796 stop_train_button .click (
@@ -880,31 +857,21 @@ def enforce_terms(terms_accepted, *args):
880857 )
881858
882859 def toggle_visible (checkbox ):
883- return { " visible" : checkbox , "__type__" : "update" }
860+ return gr . update ( visible = checkbox )
884861
885862 def toggle_pretrained (pretrained , custom_pretrained ):
886- if custom_pretrained == False :
887- return {"visible" : pretrained , "__type__" : "update" }, {
888- "visible" : False ,
889- "__type__" : "update" ,
890- }
863+ if not custom_pretrained :
864+ return gr .update (visible = pretrained ), gr .update (visible = False )
891865 else :
892- return {"visible" : pretrained , "__type__" : "update" }, {
893- "visible" : pretrained ,
894- "__type__" : "update" ,
895- }
866+ return gr .update (visible = pretrained ), gr .update (visible = pretrained )
896867
897- def enable_stop_train_button ():
898- return {"visible" : False , "__type__" : "update" }, {
899- "visible" : True ,
900- "__type__" : "update" ,
901- }
868+ def enable_stop_train_button (terms_accepted ):
869+ if not terms_accepted :
870+ return gr .update (visible = True ), gr .update (visible = False )
871+ return gr .update (visible = False ), gr .update (visible = True )
902872
903873 def disable_stop_train_button ():
904- return {"visible" : True , "__type__" : "update" }, {
905- "visible" : False ,
906- "__type__" : "update" ,
907- }
874+ return gr .update (visible = True ), gr .update (visible = False )
908875
909876 def download_prerequisites ():
910877 gr .Info (
@@ -1030,11 +997,38 @@ def update_slider_visibility(noise_reduction):
1030997 inputs = [overtraining_detector ],
1031998 outputs = [overtraining_settings ],
1032999 )
1000+
10331001 train_button .click (
10341002 fn = enable_stop_train_button ,
1035- inputs = [],
1003+ inputs = [terms_checkbox ],
10361004 outputs = [train_button , stop_train_button ],
1005+ ).then (
1006+ fn = enforce_terms ,
1007+ inputs = [
1008+ terms_checkbox ,
1009+ model_name ,
1010+ save_every_epoch ,
1011+ save_only_latest ,
1012+ save_every_weights ,
1013+ total_epoch ,
1014+ sampling_rate ,
1015+ batch_size ,
1016+ gpu ,
1017+ overtraining_detector ,
1018+ overtraining_threshold ,
1019+ pretrained ,
1020+ cleanup ,
1021+ index_algorithm ,
1022+ cache_dataset_in_gpu ,
1023+ custom_pretrained ,
1024+ g_pretrained_path ,
1025+ d_pretrained_path ,
1026+ vocoder ,
1027+ checkpointing ,
1028+ ],
1029+ outputs = [train_output_info ],
10371030 )
1031+
10381032 train_output_info .change (
10391033 fn = disable_stop_train_button ,
10401034 inputs = [],
0 commit comments