summaryrefslogtreecommitdiff
path: root/recital_dataset.py
blob: a142c92f00d4abe87e6e016cbc45f044dff09464 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import pandas as pd
import random
import json

# 读取CSV文件
df = pd.read_csv('./recitalMachine/recitalMachine_dataset.csv')  # 使用制表符分隔
df['length_of_answer'] = df['answer'].apply(lambda x: len(str(x).split(',')))

# 按difficulty分组并创建子DataFrame列表
sub_dataframes = [group for _, group in df.groupby('difficulty')]

# 打印每个difficulty的子DataFrame(可选)
for i, sub_df in enumerate(sub_dataframes):
    print(f"Difficulty {sub_df['difficulty'].iloc[0]} 的子DataFrame:")
    print(sub_df)
    print("-" * 50)

# 创建以difficulty为键的字典
sub_dataframes_dict = {diff: group for diff, group in df.groupby('difficulty')}

# 访问特定难度的子DataFrame(例如难度4)
difficulty_4_df = sub_dataframes_dict[4]
print(difficulty_4_df)

# 创建生成器函数
def difficulty4_generator(df):
    indices = list(df.index)
    random.shuffle(indices)
    
    while True:
        if not indices:
            # 如果所有记录都已使用,重新洗牌
            indices = list(df.index)
            random.shuffle(indices)
        
        yield df.loc[indices.pop()]

# 创建生成器
difficulty4_gen = difficulty4_generator(difficulty_4_df)

# 使用示例
print("从难度4中随机抽取记录:")
for i in range(min(10, len(difficulty_4_df) * 2)):  # 演示循环抽取
    row = next(difficulty4_gen)
    print(f"{i+1}. {row['question']} -> {row['answer']}")

json_data = {}
for difficulty, group in df.groupby('difficulty'):
    # 转换为字典列表,确保所有数据都可序列化
    json_data[str(difficulty)] = group.to_dict('records')

# 保存为JSON
with open('recitalMachine_dataset.json', 'w', encoding='utf-8') as f:
    json.dump(json_data, f, ensure_ascii=False, indent=2)

print("文件已保存!")