package pro.spss.server.agent.service.handler.tools;

import com.alibaba.fastjson2.JSONArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import pro.spss.server.agent.domain.constant.AlgoNameConfig;
import pro.spss.server.agent.domain.constant.AlgoPromptConfig;
import pro.spss.server.agent.domain.constant.ChatConstants;
import pro.spss.server.agent.domain.entity.AiChatResponse;
import pro.spss.server.agent.domain.enums.ConversationStateEnum;
import pro.spss.server.agent.domain.request.RequestParams;
import pro.spss.server.agent.domain.response.ResponseMessage;
import pro.spss.server.agent.service.handler.ConversationSupport;
import pro.spss.server.agent.service.prompt.PromptBuilderService;

/**
 * Intent handler for requests in which the user directly names the algorithm to use.
 *
 * <p>The handler resolves the algorithm name carried by the intent result to an
 * algorithm id via {@link AlgoNameConfig}. When no id is returned it answers with a
 * code-303 {@link ResponseMessage}; otherwise it records the algorithm on the request,
 * appends a user message and a system message to the conversation, switches the
 * request state to {@link ConversationStateEnum#PARAM_RECOMMENDATION} and forwards
 * the conversation through {@link ConversationSupport}.
 */
@Component
public class UserAlgoIntentHandler implements IntentStateHandler {

    private static final Logger logger = LoggerFactory.getLogger(UserAlgoIntentHandler.class);

    private final AlgoNameConfig algoNameConfig;
    private final ConversationSupport support;

    // NOTE(review): this collaborator is field-injected while the other two are
    // constructor-injected; consider unifying on constructor injection once all
    // callers of the constructor can be updated.
    @Autowired
    private PromptBuilderService promptBuilderService;

    /**
     * Creates the handler with its constructor-injected collaborators.
     *
     * @param algoNameConfig used to look up an algorithm id from an algorithm name
     * @param support        used to send the assembled conversation
     */
    public UserAlgoIntentHandler(AlgoNameConfig algoNameConfig,
                                 ConversationSupport support) {
        this.algoNameConfig = algoNameConfig;
        this.support = support;
    }

    /**
     * @return the fixed identifier {@code "user_algo_selection"}
     */
    @Override
    public String getName() {
        return "user_algo_selection";
    }

    /**
     * @return the display title (Chinese text meaning "algorithm parameter recommendation tool")
     */
    @Override
    public String getTitle() {
        return "算法参数推荐工具";
    }

    /**
     * @return the tool description (Chinese text): the user has uploaded data and has
     *         decided on a specific algorithm, so this tool can recommend its parameter
     *         configuration; it should only be used when the user explicitly names an algorithm
     */
    @Override
    public String getDesc() {
        return "用户已经上传数据，并且确定了使用某个算法进行数据分析，可以使用算法参数推荐工具对算法参数进行推荐配置。"
                + "注意只有当用户明确指定了某个算法的时候才使用这个工具，系统具体集成的算法请参考算法清单。";
    }

    /**
     * @return a sample user prompt (Chinese text asking to run a random-forest
     *         classification analysis on the user's data)
     */
    @Override
    public String exampleUserPrompt() {
        return "请使用随机森林算法对我的数据进行分类分析。";
    }

    /**
     * @return the condition text (Chinese text meaning "data uploaded is yes")
     */
    @Override
    public String getCondition() {
        return "是否上传数据为是";
    }

    /**
     * Handles a user-selected-algorithm intent.
     *
     * @param intentResult  intent result from which the algorithm name is read
     * @param requestParams request state; updated with the algorithm id, algorithm name
     *                      and conversation state when the algorithm is resolved
     * @param messages      conversation messages; two entries are appended when the
     *                      algorithm is resolved
     * @param prompt        the user's prompt text
     * @param token         not used by this handler
     * @return a code-303 message when no algorithm id is found for the name, otherwise
     *         the result of {@code support.sendRequestWithMessages}
     */
    @Override
    public ResponseMessage handle(AiChatResponse intentResult, RequestParams requestParams, JSONArray messages, String prompt, String token) {
        final String requestedAlgo = intentResult.getAlgoName();
        logger.debug("用户指定算法: {}", requestedAlgo);

        final String resolvedId = algoNameConfig.getAlgoId(requestedAlgo);
        if (resolvedId == null) {
            return unknownAlgorithmResponse(requestedAlgo);
        }

        requestParams.updateAlgoId(resolvedId);
        requestParams.updateAlgoName(requestedAlgo);

        // Enriched context: data summary + user intent, then the system-side
        // algorithm parameter prompt. The user message must be appended first.
        appendUserContext(messages, requestParams, prompt);
        appendAlgoParamsPrompt(messages, resolvedId);

        requestParams.setState(ConversationStateEnum.PARAM_RECOMMENDATION);
        return support.sendRequestWithMessages(messages, requestParams);
    }

    /** Builds the code-303 reply returned when no algorithm id is found for {@code algoName}. */
    private ResponseMessage unknownAlgorithmResponse(String algoName) {
        final String text = "未找到或者名称有歧义的算法: " + algoName + "，请确认算法名称是否正确。";
        final ResponseMessage reply = new ResponseMessage();
        reply.setCode(303);
        reply.setMessage(text);
        reply.setResponse(text);
        return reply;
    }

    /** Appends a user-role message combining the request's data summary with the user's prompt. */
    private void appendUserContext(JSONArray messages, RequestParams requestParams, String prompt) {
        final String dataSection = ChatConstants.DATA_PROMPT + requestParams.getDataSummary();
        final String userContext = dataSection + ChatConstants.USER_PROMPT + prompt;
        messages.add(ChatConstants.createMessage(ChatConstants.USER_ROLE, userContext));
    }

    /** Appends a system-role message holding the built parameter prompt for {@code algoId}. */
    private void appendAlgoParamsPrompt(JSONArray messages, String algoId) {
        final String builtPrompt = promptBuilderService.build(AlgoPromptConfig.getPromptById(algoId));
        messages.add(ChatConstants.createMessage(ChatConstants.ROLE_SYSTEM, builtPrompt));
    }
}
