Model Training

PHPer · 2025-10-12

Pulling together some notes on model training here. 20251012 0628...


Updated: 2025-10-12 06:29:11

    Model training notebook: DeepSeekCoder-demo.ipynb

    # ================================================================
    # 0. One-click cell: environment setup + auto-resume LoRA training
    # ================================================================
    # Press ▶️ each time: resumes if saved weights exist, starts fresh otherwise
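    # NOTE: the hard pin transformers==4.40.* below is what triggers the
    # EncoderDecoderCache import error reproduced at the end of this page;
    # a possible fix is sketched after the traceback.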
    !pip install -q -U transformers==4.40.* datasets accelerate peft bitsandbytes
    from google.colab import drive; drive.mount('/content/drive')
    # ========== The only section you need to edit ==========
    MODEL_NAME   = "deepseek-ai/deepseek-coder-7b-instruct-v1.5"
    OUTPUT_DIR   = "/content/drive/MyDrive/dsc7b_lora"   # where weights are saved / loaded from
    DATA_JSONL   = "/content/drive/MyDrive/train.jsonl"  # your data
    MAX_SAMPLES  = 10_000        # set to None for the full dataset
    MAX_LENGTH   = 512
    BATCH_SIZE   = 1
    GRAD_ACCUM   = 16
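    # effective batch size = BATCH_SIZE * GRAD_ACCUM = 1 * 16 = 16 sequences per optimizer step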
    EPOCHS       = 1
    LR           = 2e-4
    WARMUP       = 100
    SAVE_STEPS   = 200
    LOG_STEPS    = 10
    # ====================================
    import torch, json, os, random, glob
    from transformers import (
        AutoTokenizer, AutoModelForCausalLM,
        TrainingArguments, Trainer, DataCollatorForLanguageModeling
    )
    from peft import LoraConfig, get_peft_model, PeftModel, TaskType
    from datasets import load_dataset, Dataset
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    # ---------- 1. Base model ----------
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )
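    # Note: bf16 needs an Ampere-or-newer GPU (e.g. A100/L4); on a Colab T4 you would
    # fall back to fp16. Weights alone are ~14 GB here (7B params * 2 bytes in bf16).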
    # ---------- 2. Auto-detect: is there already a LoRA to resume? ----------
    if os.path.exists(os.path.join(OUTPUT_DIR, "adapter_config.json")):
        print("🔍 Existing LoRA detected, loading it to resume training...")
        # is_trainable=True is required for continued training; the default loads the adapter frozen
        model = PeftModel.from_pretrained(base_model, OUTPUT_DIR, is_trainable=True)
        model.enable_input_require_grads()
    else:
        print("🚀 No existing LoRA found, starting training from scratch...")
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=64, lora_alpha=128, lora_dropout=0.05,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                            "gate_proj", "up_proj", "down_proj"]
        )
        model = get_peft_model(base_model, lora_config)
        model.enable_input_require_grads()  # needed because gradient_checkpointing is on and the base model is frozen
        model.print_trainable_parameters()
    # ---------- 3. Data ----------
    if not os.path.exists(DATA_JSONL):
        print("⚠️ 未找到你自己的数据,自动用官方 demo 数据示范...")
        !git clone https://github.com/deepseek-ai/deepseek-coder.git /tmp/ds_code
        all_text = []
        for f in glob.glob("/tmp/ds_code/data/*.jsonl"):
            with open(f) as fin:
                for line in fin:
                    sample = json.loads(line)
                    all_text.append(sample["input"] + sample["output"])
        random.shuffle(all_text)
        all_text = all_text[:MAX_SAMPLES]
        ds = load_dataset("json", data_files={"train": all_text}, split="train")
    else:
        ds = load_dataset("json", data_files=DATA_JSONL, split="train")
        if MAX_SAMPLES:
            ds = ds.select(range(min(MAX_SAMPLES, len(ds))))
    def tokenize(example):
        out = tokenizer(
            example["text"],
            truncation=True,
            max_length=MAX_LENGTH,
            return_overflowing_tokens=False,
        )
        return {"input_ids": out["input_ids"], "attention_mask": out["attention_mask"]}
    ds = ds.map(tokenize, num_proc=4, remove_columns=ds.column_names)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
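    # mlm=False gives the plain causal-LM objective: the collator copies input_ids into
    # labels (padding masked to -100); the logits/labels shift happens inside the model.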
    # ---------- 4. Training arguments ----------
    args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        overwrite_output_dir=True,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM,
        learning_rate=LR,
        warmup_steps=WARMUP,
        logging_steps=LOG_STEPS,
        save_steps=SAVE_STEPS,
        save_total_limit=2,
        report_to="none",
        bf16=True,
        dataloader_drop_last=True,
        gradient_checkpointing=True,
    )
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=ds,
        data_collator=data_collator,
    )
    # ---------- 5. Run! ----------
    trainer.train()
    trainer.save_model(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    print("✅ 训练完成,权重已写入:", OUTPUT_DIR)     执行报错 :    ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 138.0/138.0 kB 8.2 MB/s eta 0:00:00
       ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 9.0/9.0 MB 67.5 MB/s eta 0:00:00
       ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 506.3/506.3 kB 36.9 MB/s eta 0:00:00
       ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 60.1/60.1 MB 16.7 MB/s eta 0:00:00
       ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.8/42.8 MB 13.2 MB/s eta 0:00:00
       ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.6/3.6 MB 62.9 MB/s eta 0:00:00
    ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
    pylibcudf-cu12 25.6.0 requires pyarrow<20.0.0a0,>=14.0.0; platform_machine == "x86_64", but you have pyarrow 21.0.0 which is incompatible.
    cudf-cu12 25.6.0 requires pyarrow<20.0.0a0,>=14.0.0; platform_machine == "x86_64", but you have pyarrow 21.0.0 which is incompatible.
    sentence-transformers 5.1.1 requires transformers<5.0.0,>=4.41.0, but you have transformers 4.40.2 which is incompatible.
    Mounted at /content/drive
    ---------------------------------------------------------------------------
    ImportError                               Traceback (most recent call last)
    /usr/local/lib/python3.12/dist-packages/transformers/utils/import_utils.py in _get_module(self, module_name)
       1509         try:
    -> 1510             return importlib.import_module("." + module_name, self.__name__)
       1511         except Exception as e:
    14 frames
    ImportError: cannot import name 'EncoderDecoderCache' from 'transformers' (/usr/local/lib/python3.12/dist-packages/transformers/__init__.py)
    The above exception was the direct cause of the following exception:
    RuntimeError                              Traceback (most recent call last)
    /usr/local/lib/python3.12/dist-packages/transformers/utils/import_utils.py in _get_module(self, module_name)
       1510             return importlib.import_module("." + module_name, self.__name__)
       1511         except Exception as e:
    -> 1512             raise RuntimeError(
       1513                 f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its"
       1514                 f" traceback):\n{e}"
    RuntimeError: Failed to import transformers.trainer because of the following error (look up to see its traceback):
    cannot import name 'EncoderDecoderCache' from 'transformers' (/usr/local/lib/python3.12/dist-packages/transformers/__init__.py)
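
    The root cause appears to be a version clash: the cell pins transformers==4.40.*, while the newer peft/accelerate builds that pip resolves alongside it try to import EncoderDecoderCache, which transformers 4.40 does not yet export (the sentence-transformers conflict above complains about the same pin). A minimal fix sketch, assuming a current Colab runtime: install without the pin and restart the runtime before re-running the cell.

    # Fix sketch: let pip pick a transformers release that exports EncoderDecoderCache
    !pip install -q -U transformers datasets accelerate peft bitsandbytes
    # Then Runtime -> Restart runtime, and run the training cell again.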
    
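    Once the environment is fixed and a run completes, a quick smoke test of the saved adapter might look like this (a sketch reusing MODEL_NAME and OUTPUT_DIR from the cell above):

    # Smoke test (sketch): load the saved LoRA adapter and generate once
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM
    from peft import PeftModel

    tok = AutoTokenizer.from_pretrained(OUTPUT_DIR)
    base = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, torch_dtype=torch.bfloat16,
        device_map="auto", trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(base, OUTPUT_DIR)  # frozen adapter is fine for inference
    model.eval()

    prompt = "Write a Python function that reverses a string."
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=128)
    print(tok.decode(out[0], skip_special_tokens=True))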

    20251012 0629 ...

    Updated: 2025-10-12 06:32:03