summary refs log tree commit diff
path: root/recital_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'recital_dataset.py')
-rw-r--r-- recital_dataset.py 56
1 files changed, 56 insertions, 0 deletions
diff --git a/recital_dataset.py b/recital_dataset.py
new file mode 100644
index 0000000..73a9aca
--- /dev/null
+++ b/recital_dataset.py
@@ -0,0 +1,56 @@
+import pandas as pd
+import random
+import json
+
+# Load the dataset CSV and add a column counting the comma-separated items in each answer.
+df = pd.read_csv('recitalMachine_dataset.csv') # NOTE(review): the original comment said "tab-separated", but no sep= is passed, so pandas' default comma delimiter applies -- confirm the file format
+df['length_of_answer'] = df['answer'].apply(lambda x: len(str(x).split(','))) # str() lets non-string cells (e.g. NaN) be split too; they count as one item
+
+# Group by difficulty and collect the per-difficulty sub-DataFrames into a list.
+sub_dataframes = [group for _, group in df.groupby('difficulty')]
+
+# Print each difficulty's sub-DataFrame (optional).
+for i, sub_df in enumerate(sub_dataframes): # i is not used by the print statements below
+ print(f"Difficulty {sub_df['difficulty'].iloc[0]} 的子DataFrame:")
+ print(sub_df)
+ print("-" * 50)
+
+# Build a dict of sub-DataFrames keyed by difficulty value (second groupby over the same column).
+sub_dataframes_dict = {diff: group for diff, group in df.groupby('difficulty')}
+
+# Look up the sub-DataFrame for one specific difficulty (here: 4).
+difficulty_4_df = sub_dataframes_dict[4] # NOTE(review): raises KeyError if no group key equals the integer 4 -- confirm against the data
+print(difficulty_4_df)
+
+# Generator that serves the rows of a DataFrame in endlessly repeating shuffled passes.
+def difficulty4_generator(df):
+    """Yield rows of ``df`` one at a time, in random order, forever.
+
+    Every row is yielded exactly once per pass; when a pass is exhausted the
+    order is reshuffled and a new pass begins.
+
+    Args:
+        df: pandas DataFrame to draw rows from.
+
+    Yields:
+        pandas.Series: one row of ``df``. An empty ``df`` yields nothing
+        (previously the first ``next()`` raised IndexError).
+    """
+    row_count = len(df)
+    if row_count == 0:
+        # Nothing to draw from: finish the generator instead of popping from
+        # an empty list.
+        return
+
+    remaining = []
+    while True:
+        if not remaining:
+            # Start a new pass. Positions rather than index labels are
+            # shuffled, so a DataFrame with duplicate labels still yields a
+            # single row per draw (df.loc[label] would return every match).
+            remaining = list(range(row_count))
+            random.shuffle(remaining)
+
+        yield df.iloc[remaining.pop()]
+
+# Instantiate the generator over the difficulty-4 rows.
+difficulty4_gen = difficulty4_generator(difficulty_4_df)
+
+# Usage example.
+print("从难度4中随机抽取记录:")
+for i in range(min(10, len(difficulty_4_df) * 2)): # At most 10 draws, capped at twice the row count, to demonstrate cycling through the rows
+ row = next(difficulty4_gen)
+ print(f"{i+1}. {row['question']} -> {row['answer']}")
+
+json_data = {} # maps str(difficulty) -> list of row dicts
+for difficulty, group in df.groupby('difficulty'):
+ # Convert the group to a list of row dicts. NOTE(review): no extra conversion is done here, so JSON-serializability relies entirely on to_dict('records') -- confirm, especially for missing values
+ json_data[str(difficulty)] = group.to_dict('records')
+
+# Save as JSON.
+with open('recitalMachine_dataset.json', 'w', encoding='utf-8') as f:
+ json.dump(json_data, f, ensure_ascii=False, indent=2)
+
+print("文件已保存!")
\ No newline at end of file