diff options
Diffstat (limited to 'recital_dataset.py')
-rw-r--r-- | recital_dataset.py | 56 |
1 files changed, 56 insertions, 0 deletions
"""Split the recital-machine dataset by difficulty and export it as JSON.

Running this script:

1. reads ``recitalMachine_dataset.csv`` and adds a ``length_of_answer`` column,
2. prints every per-difficulty subset,
3. demonstrates endless random sampling from the difficulty-4 subset,
4. writes all rows, grouped by difficulty, to ``recitalMachine_dataset.json``.
"""
import json
import random

import pandas as pd


def difficulty4_generator(df):
    """Yield the rows of *df* in random order, forever.

    Every row is produced exactly once per pass; when a pass is exhausted the
    rows are reshuffled and a new pass begins.

    Args:
        df: DataFrame to sample rows from.

    Yields:
        One row of *df* at a time, as a ``pandas.Series``.

    An empty *df* produces an empty generator instead of failing with
    ``IndexError`` on the first ``next()``.
    """
    positions = []
    while True:
        if not positions:
            # Start a new pass over every row.
            positions = list(range(len(df)))
            if not positions:
                return  # nothing to sample from
            random.shuffle(positions)

        # Positional access always returns a single row, even when the index
        # contains duplicate labels (``.loc`` would return a DataFrame then).
        yield df.iloc[positions.pop()]


def _json_safe_records(frame):
    """Return the rows of *frame* as dicts with missing values set to ``None``.

    ``json.dump`` writes NaN as the bare token ``NaN``, which is not valid
    JSON; ``None`` is written as ``null`` instead.
    """
    return [
        {key: (None if pd.isna(value) else value) for key, value in record.items()}
        for record in frame.to_dict('records')
    ]


# Read the CSV file with pandas' default separator (comma).
# NOTE(review): the original comment said the file is tab-separated, but no
# ``sep`` argument was ever passed -- confirm the real delimiter.
df = pd.read_csv('recitalMachine_dataset.csv')
# Number of comma-separated items in each answer (values are stringified first).
df['length_of_answer'] = df['answer'].apply(lambda x: len(str(x).split(',')))

# Group by difficulty once; the dict and the list below share the same groups.
sub_dataframes_dict = {diff: group for diff, group in df.groupby('difficulty')}
sub_dataframes = list(sub_dataframes_dict.values())

# Print the subset for every difficulty (optional, for inspection).
for sub_df in sub_dataframes:
    print(f"Difficulty {sub_df['difficulty'].iloc[0]} 的子DataFrame:")
    print(sub_df)
    print("-" * 50)

# Subset for difficulty 4. Fall back to an empty frame (same columns) so that
# a dataset without difficulty-4 rows does not abort the JSON export below.
difficulty_4_df = sub_dataframes_dict.get(4, df.iloc[0:0])
print(difficulty_4_df)

# Endless random sampler over the difficulty-4 rows.
difficulty4_gen = difficulty4_generator(difficulty_4_df)

# Usage example: drawing up to twice the subset size shows the reshuffle.
print("从难度4中随机抽取记录:")
for i in range(min(10, len(difficulty_4_df) * 2)):
    row = next(difficulty4_gen)
    print(f"{i+1}. {row['question']} -> {row['answer']}")

# One list of row dicts per difficulty; JSON object keys must be strings.
json_data = {
    str(difficulty): _json_safe_records(group)
    for difficulty, group in sub_dataframes_dict.items()
}

# Save as JSON, keeping non-ASCII text readable.
with open('recitalMachine_dataset.json', 'w', encoding='utf-8') as f:
    json.dump(json_data, f, ensure_ascii=False, indent=2)

print("文件已保存!")
\ No newline at end of file |