summaryrefslogtreecommitdiff
path: root/recitalMachine2.py
blob: 4a328e5957a02825db5f619d998136101a6cfedb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
import gradio as gr
import openai
import time
import json
from datetime import datetime
import os
import random
from collections import defaultdict
from uuid import uuid4
from typing import Dict
import threading


#TODO
#Restart game - Done - Need to reset iterator
#Local knowledge base - Done
#dynamic difficulty - Done
#save log locally - Done
#multi user status


QUESTIONS_DATASET_PATH='./recitalMachine/recitalMachine_dataset.json'
QUESTION_MODE='LOCAL' #LOCAL:本地题库,AI:完全由AI负责出题

class GameStatus:
    def __init__(self):
        # 游戏状态
        self.questionIterator=QuestionIterator()
        self.conversation_history_hidden=[]
        self.correct_count=0
        self.is_correct=False
        self.difficulty_level=1
        self.is_game_over=False
        self.initial_message=[]
        self.next_question=None
        self.conversation_history_hidden.append({"role": "system", "content": GameController.system_prompt})
        GameController.get_initial_chat_display(self)

class GameController:
    model="Moonshot-Kimi-K2-Instruct"
    #model="deepseek-v3"

    client = openai.OpenAI(
        api_key=os.getenv('ALI_BAILIAN_API_KEY'),  
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
        )
    
    system_prompt = """
    你是架空朝代的皇帝,文治武功,英明神武,权术高超,文武百官无不敬畏。我是你的儿子,今年十岁。
    今日你来检查我的学业,心情尚可,但要求极为严苛。

    # 角色设定:
    - **身份**:严父与君王的结合体,对继承人期望极高。
    - **语气**:威严、简洁、不容置疑,带有帝王般的压迫感。常用“朕”、“皇子”、“皇儿”称呼。
    - **核心行为**:化身无情的出题机器,持续考察皇子对国学经典的掌握。

    # 出题规则:
    1.  **出题内容**:系统会在提示词中给出“已知上句,回答下句”的题目。你需要用皇帝的语气复述这道题。
    2.  **反馈机制**:
        - 若皇子答对,给予简单的正反馈,然后立即出下一题。保持压力。
        - 若皇子答错,你必须立即予以斥责,并打皇子十下戒尺。你会提示正确答案,然后让皇子再背。如果系统给出了下一题的提示,要等这一题答对之后再问。
    3. **边界条件**:如果皇子的回答十分出格,例如不背书了要去蹴鞠,你要给予额外的严厉惩罚。惩罚完继续出题。
    4.  **终止条件**: 除非皇子主动哭泣、求饶(说出“我错了”、“别打了”、“疼”等类似词),否则绝不停下出题。一旦皇子求饶,你可表现出失望又略带一丝心疼的情绪,并结束考验。结束时总结一共答对了几题,答错了几题,惩罚有哪些,并输出<游戏结束>作为标记。
    """

    initial_message = "(在书房里恭敬地站在你面前) 父皇今日可要考校儿臣功课?"

    @staticmethod
    def get_ai_response(g:GameStatus):
        '''
        Given the current hidden conversation history, get AI response, and append the response in the history
        '''
        response = GameController.client.chat.completions.create(
            model=GameController.model,
            messages=g.conversation_history_hidden,
            temperature=0.7, # 温度不宜过高,保证出题的准确性
            top_p=0.95,
            stream=True,
            #max_tokens=150,
        )
            
        full_response = ""
        for chunk in response:
            if chunk.choices[0].delta.content is not None:
                chunk_content = chunk.choices[0].delta.content
                full_response += chunk_content
                yield chunk_content  # 逐块返回
      
        # 将完整的AI回复加入历史
        g.conversation_history_hidden.append({"role": "assistant", "content": full_response})
        GameController.update_status(g,full_response)  


    @staticmethod
    def chat_with_ai(g:GameStatus,message, chat_history):
        '''
        Take the chat history from gradio, get ai response and return the updated chat history
        @params
        message: the new user input
        chat_history: a list of chat history visible to users
        '''
        # 添加用户消息到历史
        g.conversation_history_hidden.append({"role": "user", "content": message})
        if not g.is_game_over: 
            if g.is_correct:
                GameController.hint_next_question(g)
                g.is_correct = False

        chat_history.append({'role':'user','content':message})
        chat_history.append({'role':'assistant', 'content':""})
        
        # 获取流式响应并逐步更新聊天界面
        full_response = ""
        for chunk in GameController.get_ai_response(g):
            full_response += chunk
            chat_history[-1] = {'role':'assistant', 'content': full_response}
            yield chat_history  # 逐步更新界面
        
        return chat_history

    @staticmethod
    def get_initial_chat_display(g: GameStatus):
        
        g.conversation_history_hidden.append({"role": "user", "content": GameController.initial_message})
        GameController.hint_next_question(g)
        
        # 获取初始响应(非流式)
        response = GameController.client.chat.completions.create(
            model=GameController.model,
            messages=g.conversation_history_hidden,
            temperature=0.7,
            top_p=0.95
        )
        
        ai_response = response.choices[0].message.content.strip()
        g.conversation_history_hidden.append({"role": "assistant", "content": ai_response})
        GameController.update_status(g,ai_response)
        
        g.initial_message.append({'role':'user','content':GameController.initial_message})
        g.initial_message.append({'role':'assistant','content':ai_response})

    @staticmethod
    def update_status(g: GameStatus,ai_response):
        end_mark="游戏结束"
        if not g.is_game_over:
            if end_mark in ai_response:
                g.is_game_over=True
                return
            if g.next_question in ai_response:
                g.is_correct = True           
                g.correct_count += 1
                g.difficulty_level += 1
                    
    @staticmethod
    def hint_next_question(g: GameStatus):
        if g.difficulty_level > 9:
            return
        if g.difficulty_level == 9:
            g.conversation_history_hidden.append({"role": "user", "content": '[系统提示]:这是最后一题,答完这题无论对错都停止游戏'})
        
        try:
            next_question = g.questionIterator.get_next_question(g.difficulty_level)['question']
            g.conversation_history_hidden.append({"role": "user", "content": f'[系统提示]:下一题,{next_question}'})
            g.next_question=next_question
        except:
            return
        
    @staticmethod
    def restart(g:GameStatus,chat_history):
        g.conversation_history_hidden=[]
        g.correct_count=0
        g.is_correct=False
        g.difficulty_level=1
        g.is_game_over=False
        g.initial_message=[]
        g.next_question=None
        g.conversation_history_hidden.append({"role": "system", "content": GameController.system_prompt})
        GameController.get_initial_chat_display(g)

        chat_history=g.initial_message
        g.questionIterator.reset()

        return chat_history
    
    @staticmethod
    def clear_history(g:GameStatus):
        """清空对话历史"""
        g.conversation_history_hidden = []
        return "历史已清空"

    @staticmethod
    def show_history(g:GameStatus):
        """显示当前对话历史"""
        history_text = ""
        for msg in g.conversation_history_hidden:
            role = "用户" if msg["role"] == "user" else "AI" if msg["role"] == "assistant" else "系统"
            history_text += f"{role}: {msg['content']}\n\n"
        return history_text if history_text else "暂无对话历史"

    @staticmethod
    def save_conversation_history(g:GameStatus):
        """保存对话历史到JSON文件"""
        try:
            # 创建保存目录(如果不存在)
            os.makedirs("gradio_history", exist_ok=True)
            
            # 生成文件名(包含时间戳)
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"gradio_history/conversation_{timestamp}.json"
            
            # 准备要保存的数据
            save_data = {
                "save_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                #"total_rounds": self.game_state["total_rounds"],
                "correct_count": g.correct_count,
                "conversation": g.conversation_history_hidden
            }
            
            # 保存到文件
            with open(filename, 'w', encoding='utf-8') as f:
                json.dump(save_data, f, ensure_ascii=False, indent=2)
            
            print(f"对话历史已保存到: {filename}")
            return
        
        except Exception as e:
            print(f"保存失败: {str(e)}")
            return

class QuestionIterator:
    question_set=None
    def __init__(self):
        if QuestionIterator.question_set is None:
            QuestionIterator.question_set=QuestionIterator.load_question_set(QUESTIONS_DATASET_PATH)
            # # For debug purpose, print the question set to make sure it is correctly loaded
            # for i, question_list in QuestionIterator.question_set.items():
            #     print(question_list)
            #     print("-" * 50)
        
        self.iterator_dict = {}
    
        for difficulty_str, question_list in QuestionIterator.question_set.items():
            difficulty = int(difficulty_str)
            shuffled_questions = question_list.copy()
            random.shuffle(shuffled_questions)
            self.iterator_dict[difficulty]=iter(shuffled_questions)



    def load_question_set(json_file_path):
        with open(json_file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        
        # 将字典转换回便于使用的格式
        sub_dataframes_dict = {}
        for difficulty_str, records in data.items():
            difficulty = int(difficulty_str)
            sub_dataframes_dict[difficulty] = records
        
        return sub_dataframes_dict

        # 使用示例
        # loaded_data = load_sub_dataframes('sub_dataframes.json')
        # print(loaded_data[4])  # 访问难度4的数据

    def get_next_question(self,difficulty):
        '''
        generate next question
        @param
        difficulty: (int) the difficulty of the next question
        @return
        question: the question
        length: the length of the expected answer
        '''
        # 检查该难度是否有可用问题
    
        try:
            item = next(self.iterator_dict[difficulty])
            print(f"下一题: {item}")
        except StopIteration:
            raise Exception("这个难度的题出完了")
            
        
        return item

    def reset_difficulty(self, difficulty):
        """
        重置指定难度的使用记录
        @param difficulty: 要重置的难度
        """
        difficulty = int(difficulty)
        if str(difficulty) in QuestionIterator.question_set:
            shuffled_questions = QuestionIterator.question_set[difficulty].copy()
            random.shuffle(shuffled_questions)
            self.iterator_dict[difficulty]=iter(shuffled_questions)

    def reset(self):
        for difficulty in [1,2,3,4,5,6,7,8,9]:
            self.reset_difficulty(difficulty)

class GRServer:
    def __init__(self):
        # 使用线程安全的字典存储用户状态
        self.user_sessions: Dict[str, GameStatus] = {}
        self.session_lock = threading.Lock()

    def get_or_create_session(self, chatbot,session_id: str = None):

        """获取或创建用户会话"""
        if not session_id:
            session_id = str(uuid4())
        
        with self.session_lock:
            if session_id not in self.user_sessions:
                g=GameStatus()
                chatbot=g.initial_message
                self.user_sessions[session_id] = g

        return session_id,chatbot
    
    def session_chat(self,session_id,msg,chat_history):
        g=self.user_sessions[session_id]
        
        for updated_chat_history in GameController.chat_with_ai(g, msg, chat_history):
            yield "", chat_history
        
        return "", updated_chat_history
    
    def session_restart(self, session_id, chat_history):
        g=self.user_sessions[session_id]
        chat_history=GameController.restart(g, chat_history)
        return chat_history
    
    def session_save(self, session_id):
        g=self.user_sessions[session_id]
        GameController.save_conversation_history(g)

    def create_interface(self):
        # 创建Gradio界面
        with gr.Blocks(title="皇帝出题机") as demo:
            session_state = gr.State(value="")
            gr.Markdown("# 皇帝出题机")
        
            with gr.Row():
                with gr.Column(scale=2):
                    chatbot = gr.Chatbot(label="对话界面",height=600,type='messages')
                    msg = gr.Textbox(label="输入消息", placeholder="在这里输入你的消息...")

                    with gr.Column():
                        btn_send = gr.Button("发送", variant="primary")
                        with gr.Row():
                            btn_save = gr.Button("保存对话历史", variant="secondary")
                            btn_restart = gr.Button("重新开始", variant="secondary")
                
            # 页面加载时初始化会话
            demo.load(
                fn=self.get_or_create_session,
                inputs=[chatbot],
                outputs=[session_state,chatbot]
            )

            btn_send.click(self.session_chat, [session_state,msg, chatbot], [msg, chatbot])
            msg.submit(self.session_chat, [session_state, msg, chatbot], [msg, chatbot])  # 回车键触发
            btn_save.click(
                fn=self.session_save,
                inputs=[session_state],
                outputs=[]
            )
            btn_restart.click(fn=self.session_restart,inputs=[session_state,chatbot],outputs=[chatbot])
        
        demo.launch()

if __name__ == "__main__":
    server=GRServer()
    server.create_interface()
    # g=GameController()
    # question_iterator = QuestionIterator()

    # while True:
    #     try:
    #         question = question_iterator.get_next_question(4)
    #         print(f"问题: {question['question']}, 长度: {question['length_of_answer']}")
    #     except:
    #         break

    # question = question_iterator.get_next_question(6)
    # print(f"问题: {question['question']}, 长度: {question['length_of_answer']}")

    # question_iterator.reset_difficulty(4)
    #create_interface(g)