1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+ import argparse
5
+ import datasets
6
+ import torch
7
+ import re
8
+ from thefuzz import process
9
+ from typing import List
10
+ from tqdm import tqdm
11
+ from transformers .trainer_utils import set_seed
12
+
13
+ '''
14
+ wget https://huggingface.co/datasets/ceval/ceval-exam/resolve/main/ceval-exam.zip
15
+ mkdir data/ceval
16
+ mv ceval-exam.zip data/ceval
17
+ cd data/ceval; unzip ceval-exam.zip
18
+ cd ../../
19
+
20
+ pip install thefuzz
21
+ python eval/evaluate_chat_ceval.py -d data/ceval
22
+ '''
23
+
24
+ def load_models_tokenizer (args ):
25
+ from transformers import AutoModelForCausalLM , AutoTokenizer
26
+ from transformers .generation import GenerationConfig
27
+
28
+ tokenizer = AutoTokenizer .from_pretrained (args .checkpoint_path , trust_remote_code = True )
29
+ model = AutoModelForCausalLM .from_pretrained (args .checkpoint_path , device_map = "auto" , trust_remote_code = True , bf16 = True , use_flash_attn = True ).eval ()
30
+ model .generation_config = GenerationConfig .from_pretrained (args .checkpoint_path , trust_remote_code = True )
31
+ model .generation_config .do_sample = False # use greedy decoding
32
+ return model , tokenizer
33
+
34
+ def process_before_extraction (gen , question , choice_dict ):
35
+ # Example Prompt:
36
+ # 关于传输层的面向连接服务的特性是____。
37
+ # A. 既不保证可靠,也不保证按序交付
38
+ # B. 不保证可靠,但保证按序交付
39
+ # C. 保证可靠,但不保证按序交付
40
+ # D. 既保证可靠,也保证按序交付
41
+ # Example Model Output:
42
+ # 关于传输层的面向连接服务的特性是既保证可靠,也保证按序交付
43
+ # Processed Output:
44
+ # 答案是D
45
+
46
+ question_split = question .rstrip ("。" ).split ("。" )[- 1 ].split ("_" )
47
+
48
+ # replacing the question
49
+ if len (question_split [0 ].strip ()) > 4 :
50
+ gen = gen .replace (question_split [0 ], "答案是" )
51
+ if len (question_split [- 1 ].strip ()) > 4 :
52
+ gen = gen .replace (question_split [- 1 ], "" )
53
+
54
+ # replace the choice by letter in the generated sentence
55
+ # from longest one to shortest one
56
+ for key , val in sorted (choice_dict .items (), key = lambda x : len (x [1 ]), reverse = True ):
57
+ gen = gen .replace (val .rstrip ("。" ), key )
58
+ return gen
59
+
60
+ def count_substr (gen , pattern ):
61
+ return len (re .findall (pattern , gen ))
62
+
63
+ def extract_choice (gen , prompt , choice_list ):
64
+ # 答案是A | 选项是A | 应该选A选项
65
+ res = re .search (r"(?:(?:选|选择|选定)|(?:(?:答案|选项)(?![^ABCD]{0,10}?(?:不|非)[^ABCD]{0,10}?(?:是|为|:|:|】))[^ABCD]{0,10}?(?:是|为|:|:|】))[^ABCD]{0,10}?)(A|B|C|D)(?:选项)?(?:\)|。|\.|,|,|.|、|A|B|C|D|$)" , gen )
66
+
67
+ # A选项正确 | A选项符合题意
68
+ if res is None :
69
+ res = re .search (r"(A|B|C|D)(?:选?项)?(?![^ABCD]{0,4}?(?:不|非)[^ABCD]{0,4}?(?:正确|对|符合))[^ABCD]{0,4}?(?:正确|对|符合)" , gen )
70
+
71
+ # 直接输出 A
72
+ if res is None :
73
+ res = re .search (r"^(A|B|C|D)(?:。|\.|,|,|.|$)" , gen )
74
+
75
+ # 获取第一个出现的字母
76
+ if res is None :
77
+ res = re .search (r"(?<![a-zA-Z])(A|B|C|D)(?![a-zA-Z=])" , gen )
78
+
79
+ if res is None :
80
+ return choices [choice_list .index (process .extractOne (gen , choice_list )[0 ])]
81
+ else :
82
+ return res .group (1 )
83
+
84
+ def format_example (line ):
85
+ example = line ['question' ] + "\n \n "
86
+ for choice in choices :
87
+ example += f'{ choice } . { line [f"{ choice } " ]} \n '
88
+ return example
89
+
90
+ def extract_answer (response , row ):
91
+ prompt = row ['question' ]
92
+ gen = process_before_extraction (response , prompt , {choice : row [choice ] for choice in choices })
93
+ if not isinstance (prompt , str ):
94
+ prompt = prompt [0 ]
95
+ pred = extract_choice (gen , prompt , [row [choice ] for choice in choices ])
96
+ return pred
97
+
98
+ @torch .no_grad ()
99
+ def eval_subject (
100
+ model ,
101
+ tokenizer ,
102
+ subject_name ,
103
+ test_df ,
104
+ save_result_dir = None ,
105
+ overwrite = False ,
106
+ ** kwargs
107
+ ):
108
+
109
+ result_path = os .path .join (save_result_dir , f'{ subject_name } _result.csv' )
110
+ if not overwrite and os .path .exists (result_path ):
111
+ print (f"{ result_path } existed, skip!" )
112
+ score = []
113
+ for (_ , datarow ), (_ , resultrow ) in zip (test_df .iterrows (), pd .read_csv (result_path ).iterrows ()):
114
+ pred = extract_answer (resultrow ['model_response' ], datarow )
115
+ correct = 1 if pred == datarow ['answer' ] else 0
116
+ score .append (correct )
117
+ correct_ratio = 100 * sum (score ) / len (score )
118
+ return correct_ratio
119
+
120
+ responses = []
121
+ result = []
122
+ score = []
123
+
124
+ for _ , row in tqdm (test_df .iterrows (), total = len (test_df )):
125
+ question = format_example (row )
126
+
127
+ response , history = model .chat (
128
+ tokenizer ,
129
+ question ,
130
+ history = None ,
131
+ )
132
+ print (question )
133
+ print (response )
134
+ pred = extract_answer (response , row )
135
+ print (pred )
136
+ print ("======================" )
137
+
138
+ if 'answer' in row :
139
+ correct = 1 if pred == row ['answer' ] else 0
140
+ score .append (correct )
141
+ if args .debug : print (f'{ question } pred: { pred } ref: { row ["answer" ]} ' )
142
+ responses .append (response )
143
+ result .append (pred )
144
+
145
+ if score :
146
+ correct_ratio = 100 * sum (score ) / len (score )
147
+ if args .debug : print (subject_name , correct_ratio )
148
+ else :
149
+ correct_ratio = 0
150
+ if save_result_dir :
151
+ test_df ['model_response' ] = responses
152
+ test_df ['model_output' ] = result
153
+ if score :
154
+ test_df ["correctness" ] = score
155
+ os .makedirs (save_result_dir , exist_ok = True )
156
+ test_df .to_csv (result_path , encoding = "utf-8" , index = False )
157
+
158
+ return correct_ratio
159
+
160
+
161
+ def cal_ceval (res ):
162
+ acc_sum_dict = dict ()
163
+ acc_norm_sum_dict = dict ()
164
+ cnt_dict = dict ()
165
+ acc_sum = 0.
166
+ cnt = 0
167
+ hard_cnt = 0
168
+ hard_acc_sum = 0.
169
+ for tt in res .keys ():
170
+ name = tt .split ('-' )[- 1 ]
171
+ acc_sum += float (res [tt ])
172
+ cnt += 1
173
+ class_ = TASK_NAME_MAPPING [name ][2 ]
174
+ if class_ not in acc_sum_dict :
175
+ acc_sum_dict [class_ ] = 0.
176
+ acc_norm_sum_dict [class_ ] = 0.
177
+ cnt_dict [class_ ] = 0.
178
+ if name in hard_list :
179
+ hard_cnt += 1
180
+ hard_acc_sum += float (res [tt ])
181
+ acc_sum_dict [class_ ] += float (res [tt ])
182
+ cnt_dict [class_ ] += 1
183
+ print ('\n \n \n ' )
184
+ for k in ['STEM' , 'Social Science' , 'Humanities' , 'Other' ]:
185
+ if k in cnt_dict :
186
+ print ('%s acc: %.2f ' % (
187
+ k , acc_sum_dict [k ] / cnt_dict [k ]))
188
+ if hard_cnt > 0 :
189
+ print ('Hard acc:%.2f ' % (hard_acc_sum / hard_cnt ))
190
+ print ('AVERAGE acc:%.2f ' % (acc_sum / cnt ))
191
+
192
+
193
+ TASK_NAME_MAPPING = {
194
+ "computer_network" : ["Computer Network" , "\u8ba1 \u7b97 \u673a \u7f51 \u7edc " , "STEM" ],
195
+ "operating_system" : ["Operating System" , "\u64cd \u4f5c \u7cfb \u7edf " , "STEM" ],
196
+ "computer_architecture" : ["Computer Architecture" , "\u8ba1 \u7b97 \u673a \u7ec4 \u6210 " , "STEM" ],
197
+ "college_programming" : ["College Programming" , "\u5927 \u5b66 \u7f16 \u7a0b " , "STEM" ],
198
+ "college_physics" : ["College Physics" , "\u5927 \u5b66 \u7269 \u7406 " , "STEM" ],
199
+ "college_chemistry" : ["College Chemistry" , "\u5927 \u5b66 \u5316 \u5b66 " , "STEM" ],
200
+ "advanced_mathematics" : ["Advanced Mathematics" , "\u9ad8 \u7b49 \u6570 \u5b66 " , "STEM" ],
201
+ "probability_and_statistics" : ["Probability and Statistics" , "\u6982 \u7387 \u7edf \u8ba1 " , "STEM" ],
202
+ "discrete_mathematics" : ["Discrete Mathematics" , "\u79bb \u6563 \u6570 \u5b66 " , "STEM" ],
203
+ "electrical_engineer" : ["Electrical Engineer" , "\u6ce8 \u518c \u7535 \u6c14 \u5de5 \u7a0b \u5e08 " , "STEM" ],
204
+ "metrology_engineer" : ["Metrology Engineer" , "\u6ce8 \u518c \u8ba1 \u91cf \u5e08 " , "STEM" ],
205
+ "high_school_mathematics" : ["High School Mathematics" , "\u9ad8 \u4e2d \u6570 \u5b66 " , "STEM" ],
206
+ "high_school_physics" : ["High School Physics" , "\u9ad8 \u4e2d \u7269 \u7406 " , "STEM" ],
207
+ "high_school_chemistry" : ["High School Chemistry" , "\u9ad8 \u4e2d \u5316 \u5b66 " , "STEM" ],
208
+ "high_school_biology" : ["High School Biology" , "\u9ad8 \u4e2d \u751f \u7269 " , "STEM" ],
209
+ "middle_school_mathematics" : ["Middle School Mathematics" , "\u521d \u4e2d \u6570 \u5b66 " , "STEM" ],
210
+ "middle_school_biology" : ["Middle School Biology" , "\u521d \u4e2d \u751f \u7269 " , "STEM" ],
211
+ "middle_school_physics" : ["Middle School Physics" , "\u521d \u4e2d \u7269 \u7406 " , "STEM" ],
212
+ "middle_school_chemistry" : ["Middle School Chemistry" , "\u521d \u4e2d \u5316 \u5b66 " , "STEM" ],
213
+ "veterinary_medicine" : ["Veterinary Medicine" , "\u517d \u533b \u5b66 " , "STEM" ],
214
+ "college_economics" : ["College Economics" , "\u5927 \u5b66 \u7ecf \u6d4e \u5b66 " , "Social Science" ],
215
+ "business_administration" : ["Business Administration" , "\u5de5 \u5546 \u7ba1 \u7406 " , "Social Science" ],
216
+ "marxism" : ["Marxism" , "\u9a6c \u514b \u601d \u4e3b \u4e49 \u57fa \u672c \u539f \u7406 " , "Social Science" ],
217
+ "mao_zedong_thought" : ["Mao Zedong Thought" , "\u6bdb \u6cfd \u4e1c \u601d \u60f3 \u548c \u4e2d \u56fd \u7279 \u8272 \u793e \u4f1a \u4e3b \u4e49 \u7406 \u8bba \u4f53 \u7cfb \u6982 \u8bba " , "Social Science" ],
218
+ "education_science" : ["Education Science" , "\u6559 \u80b2 \u5b66 " , "Social Science" ],
219
+ "teacher_qualification" : ["Teacher Qualification" , "\u6559 \u5e08 \u8d44 \u683c " , "Social Science" ],
220
+ "high_school_politics" : ["High School Politics" , "\u9ad8 \u4e2d \u653f \u6cbb " , "Social Science" ],
221
+ "high_school_geography" : ["High School Geography" , "\u9ad8 \u4e2d \u5730 \u7406 " , "Social Science" ],
222
+ "middle_school_politics" : ["Middle School Politics" , "\u521d \u4e2d \u653f \u6cbb " , "Social Science" ],
223
+ "middle_school_geography" : ["Middle School Geography" , "\u521d \u4e2d \u5730 \u7406 " , "Social Science" ],
224
+ "modern_chinese_history" : ["Modern Chinese History" , "\u8fd1 \u4ee3 \u53f2 \u7eb2 \u8981 " , "Humanities" ],
225
+ "ideological_and_moral_cultivation" : ["Ideological and Moral Cultivation" , "\u601d \u60f3 \u9053 \u5fb7 \u4fee \u517b \u4e0e \u6cd5 \u5f8b \u57fa \u7840 " , "Humanities" ],
226
+ "logic" : ["Logic" , "\u903b \u8f91 \u5b66 " , "Humanities" ],
227
+ "law" : ["Law" , "\u6cd5 \u5b66 " , "Humanities" ],
228
+ "chinese_language_and_literature" : ["Chinese Language and Literature" , "\u4e2d \u56fd \u8bed \u8a00 \u6587 \u5b66 " , "Humanities" ],
229
+ "art_studies" : ["Art Studies" , "\u827a \u672f \u5b66 " , "Humanities" ],
230
+ "professional_tour_guide" : ["Professional Tour Guide" , "\u5bfc \u6e38 \u8d44 \u683c " , "Humanities" ],
231
+ "legal_professional" : ["Legal Professional" , "\u6cd5 \u5f8b \u804c \u4e1a \u8d44 \u683c " , "Humanities" ],
232
+ "high_school_chinese" : ["High School Chinese" , "\u9ad8 \u4e2d \u8bed \u6587 " , "Humanities" ],
233
+ "high_school_history" : ["High School History" , "\u9ad8 \u4e2d \u5386 \u53f2 " , "Humanities" ],
234
+ "middle_school_history" : ["Middle School History" , "\u521d \u4e2d \u5386 \u53f2 " , "Humanities" ],
235
+ "civil_servant" : ["Civil Servant" , "\u516c \u52a1 \u5458 " , "Other" ],
236
+ "sports_science" : ["Sports Science" , "\u4f53 \u80b2 \u5b66 " , "Other" ],
237
+ "plant_protection" : ["Plant Protection" , "\u690d \u7269 \u4fdd \u62a4 " , "Other" ],
238
+ "basic_medicine" : ["Basic Medicine" , "\u57fa \u7840 \u533b \u5b66 " , "Other" ],
239
+ "clinical_medicine" : ["Clinical Medicine" , "\u4e34 \u5e8a \u533b \u5b66 " , "Other" ],
240
+ "urban_and_rural_planner" : ["Urban and Rural Planner" , "\u6ce8 \u518c \u57ce \u4e61 \u89c4 \u5212 \u5e08 " , "Other" ],
241
+ "accountant" : ["Accountant" , "\u6ce8 \u518c \u4f1a \u8ba1 \u5e08 " , "Other" ],
242
+ "fire_engineer" : ["Fire Engineer" , "\u6ce8 \u518c \u6d88 \u9632 \u5de5 \u7a0b \u5e08 " , "Other" ],
243
+ "environmental_impact_assessment_engineer" : ["Environmental Impact Assessment Engineer" , "\u73af \u5883 \u5f71 \u54cd \u8bc4 \u4ef7 \u5de5 \u7a0b \u5e08 " , "Other" ],
244
+ "tax_accountant" : ["Tax Accountant" , "\u7a0e \u52a1 \u5e08 " , "Other" ],
245
+ "physician" : ["Physician" , "\u533b \u5e08 \u8d44 \u683c " , "Other" ]
246
+ }
247
+ hard_list = ['advanced_mathematics' , 'discrete_mathematics' , 'probability_and_statistics' , 'college_physics' , 'college_chemistry' , 'high_school_mathematics' , 'high_school_physics' , 'high_school_chemistry' ]
248
+ choices = ["A" , "B" , "C" , "D" ]
249
+
250
+
251
+ def main (args ):
252
+ print ("loading model weights" )
253
+ if args .checkpoint_path :
254
+ model , tokenizer = load_models_tokenizer (args )
255
+ else :
256
+ model , tokenizer = None , None
257
+ print ("model loaded" )
258
+ dev_result = {}
259
+ for subject_name in tqdm (TASK_NAME_MAPPING .keys ()):
260
+ val_file_path = os .path .join (args .eval_data_path , 'val' , f'{ subject_name } _val.csv' )
261
+ # dev_file_path = os.path.join(args.eval_data_path, 'dev', f'{subject_name}_dev.csv')
262
+ # test_file_path = os.path.join(args.eval_data_path, 'test', f'{subject_name}_test.csv')
263
+ val_df = pd .read_csv (val_file_path )
264
+ # dev_df = pd.read_csv(dev_file_path)
265
+ # test_df = pd.read_csv(test_file_path)
266
+
267
+ score = eval_subject (model , tokenizer , subject_name , val_df ,
268
+ save_result_dir = f"outs_chat/ceval_eval_result" , overwrite = args .overwrite )
269
+ dev_result [subject_name ] = score
270
+ cal_ceval (dev_result )
271
+
272
+
273
+ if __name__ == '__main__' :
274
+ parser = argparse .ArgumentParser (description = 'Test HF checkpoint.' )
275
+ parser .add_argument ('-c' , '--checkpoint-path' , type = str , help = 'Checkpoint path' , default = "Qwen/Qwen-7B-Chat" )
276
+ parser .add_argument ('-s' , '--seed' , type = int , default = 1234 , help = 'Random seed' )
277
+
278
+ """Provide extra arguments required for tasks."""
279
+ group = parser .add_argument_group (title = 'Evaluation options' )
280
+ group .add_argument ('-d' , '--eval_data_path' , type = str , required = True ,
281
+ help = 'Path to eval data' )
282
+ group .add_argument ("--debug" , action = 'store_true' , default = False ,
283
+ help = 'Print infos.' )
284
+ group .add_argument ("--overwrite" , action = 'store_true' , default = False ,
285
+ help = 'Overwrite existed results' )
286
+
287
+ args = parser .parse_args ()
288
+ set_seed (args .seed )
289
+
290
+ main (args )
0 commit comments