之前自己从SemEnr下载下来用,安装了readme文件中提到的依赖库,结果到处报错。于是还是好朋友licyk给了我一个优秀的解决方案。感谢他!
接下来我也对其进行一定的使用及修改,这是我的项目仓库:https://github.com/HG-dev17/SemEnr

一、anaconda和Visual Studio安装

1.1 下载

下载链接:https://www.anaconda.com/products/distribution

1.2 安装

双击运行,一直下一步即可。(如不想安装在C盘,自行选择目录即可)
注意一定要选择添加进PATH,不然终端无法运行

1.3 Visual Studio安装

下载链接:https://visualstudio.microsoft.com/zh-hans/

二、安装项目

安装项目

由于有一些问题过期了,所以此方法可能存在一定的问题。请自行根据报错进行修改即可

下面的操作将在 PowerShell 中进行,右键桌面空白处点在终端中打开就能进入 PowerShell 终端。

  1. 现在将 MicroMamba 安装在 D 盘,为 MicroMamba 创建安装目录。
    1
    New-Item -ItemType Directory -Path "D:/micromamba"
  2. 下载 MicroMamba。
    1
    Invoke-WebRequest -Uri https://gitee.com/licyk/README-collection/releases/download/archive/micromamba.exe -OutFile "D:/micromamba/micromamba.exe"
    我也fork了一份
    1
    Invoke-WebRequest -Uri https://gitee.com/HG-dev17/download/releases/download/archive/micromamba.exe -OutFile "D:/micromamba/micromamba.exe"
  3. 初始化 MicroMamba 的配置。
    1
    2
    3
    & "D:/micromamba/micromamba.exe" shell hook -s powershell | Out-String | Invoke-Expression
    & "D:/micromamba/micromamba.exe" shell init -s powershell -r "D:/micromamba"
    . "$Env:USERPROFILE/Documents/WindowsPowerShell/profile.ps1"
    注意,如果遇到以下问题:
    1
    2
    . : 无法加载文件 C:\Users\用户名\Documents\WindowsPowerShell\profile.ps1
    因为在此系统上禁止运行脚本。有关详细信息,请参阅
    解决办法是输入命令允许启动所有脚本
    1
    Set-ExecutionPolicy RemoteSigned -Scope CurrentUser
  4. 为 MicroMamba 下载 Conda 配置文件。
    1
    Invoke-WebRequest -Uri https://gitee.com/licyk/README-collection/releases/download/archive/conda_config.yaml -OutFile "$Env:USERPROFILE/.condarc"
    我fork了一份
    1
    Invoke-WebRequest -Uri https://gitee.com/HG-dev17/download/releases/download/archive/conda_config.yaml -OutFile "$Env:USERPROFILE/.condarc"
  5. 为 SemEnr 创建 Conda 环境。
    1
    micromamba create --name semenr git -y
  6. 激活 SemEnr 的 Conda 环境。
    1
    micromamba activate semenr
  7. 待会在 D 盘下载 SemEnr 这个项目,所以切换到 D 盘目录里。
    1
    Set-Location D:
  8. 因为 https://github.com/licyk/SemEnr 这个 Github 地址在没有科学上网的情况下比较难连上,所以换成 Github 镜像源地址下载 SemEnr 项目。
    1
    git clone https://ghproxy.net/https://github.com/licyk/SemEnr
    我也fork了一份,地址是 https://github.com/HG-dev17/SemEnr ,所以你也可以直接用这个地址下载。
    1
    git clone https://github.com/HG-dev17/SemEnr
  9. 进入 SemEnr 项目文件夹。
    1
    Set-Location SemEnr
  10. 使用 MicroMamba 安装 SemEnr 的依赖。
    1
    micromamba install -f environment.yml -y
  11. 待会使用 Pip 安装剩余的项目依赖,为了保证下载速度,设置一下 Pip 镜像源。
    1
    pip config set global.index-url "https://mirrors.cloud.tencent.com/pypi/simple"
  12. 使用 Pip 安装剩余的依赖。
    1
    pip install -r requirements.txt
    现在 SemEnr 项目的运行环境就配置好了,下次要进入 SemEnr 项目的运行环境就使用下面的命令。
    1
    micromamba activate semenr
  13. 放置数据集到指定目录
    /data/github
  • 最后cd到运行目录,即可使用python .\main.py运行程序了

三、使用项目

项目使用过程中,我们发现loss值并没有记录在结果文件results/training_results.txt中,所以我们需要修改一下main.py文件,将loss值记录到结果文件中。
记得去掉加号,保持缩进

1
2
3
4
5
6
            if hist.history['val_loss'][0] < val_loss['loss']:
val_loss = {'loss': hist.history['val_loss'][0], 'epoch': i}
+ # 将 val_loss 写入文件
+ f1.write('Best Validation Loss: Epoch={}, Loss={}\n'.format(val_loss['epoch'], val_loss['loss']))
+ f1.flush()
print('Best: Loss = {}, Epoch = {}'.format(val_loss['loss'], val_loss['epoch']))

3.1分析项目

数据预处理

首先是数据预处理先看CreateCorpus.py文件

1
2
3
4
5
6
7
8
9
10
if __name__ == '__main__':
sourceFile = open("DeepCom_JAVA/train_tokens.txt", "r", encoding="utf-8")
corpusFile = open("DeepCom_JAVA/corpus_tokens.txt", "w", encoding="utf-8")
for num, line in enumerate(sourceFile):
print(num)
words = line.split()
for word in words:
corpusFile.write(word + '\n')
sourceFile.close()
corpusFile.close()

这个文件的作用是将DeepCom_JAVA/train_tokens.txt文件中的每一行按空格分割,然后将分割后的每个单词写入DeepCom_JAVA/corpus_tokens.txt文件中,每行一个单词。
接着是CreateVocab.py文件
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from collections import Counter

vocabWord = open("DeepCom_JAVA/corpus_tokens.txt", "r", encoding="utf-8")
processStaFile = open("DeepCom_JAVA/vocab.tokens.txt", "w", encoding="utf-8")
staTreeList = []
while 1:
word = vocabWord.readline().splitlines()
if not word:
break
staTreeList.append(word[0])

staTreeDic = Counter(staTreeList)

for k, v in staTreeDic.items():
if v >= 11:
processStaFile.write(k + '\n')

这个文件的作用是统计DeepCom_JAVA/corpus_tokens.txt文件中每个单词出现的次数,然后将出现次数大于等于11的单词写入DeepCom_JAVA/vocab.tokens.txt文件中,每行一个单词。
接着是Vocab2pkl.py文件
1
2
3
4
5
6
7
8
9
10
11
12
import codecs
import pickle

sbt_list = []
f = codecs.open('DeepCom_JAVA/vocab.tokens.txt', encoding='utf8', errors='replace').readlines()

for line in f:
line = line.strip()
sbt_list.append(line)
print(line)
sbt_dictionary = {value: index + 1 for index, value in enumerate(sbt_list)}
pickle.dump(sbt_dictionary,open("DeepCom_JAVA/vocab.tokens.pkl", 'wb'))

这个文件的作用是将DeepCom_JAVA/vocab.tokens.txt文件中的单词转换为字典,并将字典保存为DeepCom_JAVA/vocab.tokens.pkl文件。
最后是txt2pkl.py文件
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import pickle
import codecs

def desc2index(desc_dict):
desc = []
desc_index = []

fileDesc = codecs.open('DeepCom_JAVA/test_desc.txt', encoding='utf8',
errors='replace').readlines()
for line in fileDesc:
line = line.split()
desc.append(line)

for item in desc:
new_item = []
for word in item:
try:
new_item.append(desc_dict[word])
except:
new_item.append(3)
desc_index.append(new_item)
return desc_index


def tokens2index(tokens_dict):
tokens = []
tokens_index = []

fileTokens = codecs.open('DeepCom_JAVA/test_tokens.txt', encoding='utf8',
errors='replace').readlines()
for line in fileTokens:
line = line.split()
tokens.append(line)

for item in tokens:
new_item = []
for word in item:
try:
new_item.append(tokens_dict[word])
except:
new_item.append(3)
tokens_index.append(new_item)
return tokens_index

def simWords2index(desc_dict):
tokens = []
tokens_index = []
fileTokens = codecs.open('DeepCom_JAVA/test_IR_code_desc_sw.txt', encoding='utf8',
errors='replace').readlines()
for line in fileTokens:
line = line.split()
tokens.append(line)

for item in tokens:
new_item = []
for word in item:
try:
new_item.append(desc_dict[word])
except:
new_item.append(3)
tokens_index.append(new_item)


return tokens_index

if __name__ == '__main__':

desc_dict = pickle.load(open('DeepCom_JAVA/vocab.desc.pkl', 'rb')) ##描述
tokens_dict = pickle.load(open('DeepCom_JAVA/vocab.tokens.pkl', 'rb')) ##代码
desc = desc2index(desc_dict) ##描述
tokens = tokens2index(tokens_dict) ##代码
sim_desc=simWords2index(desc_dict) ##相似代码描述

pickle.dump(desc, open('DeepCom_JAVA/test.desc.pkl', 'wb')) ##描述
pickle.dump(tokens, open('DeepCom_JAVA/test.tokens.pkl', 'wb')) ##代码
pickle.dump(sim_desc, open('DeepCom_JAVA/test_IR_code_desc.pkl', 'wb')) ##相似代码描述

print('finish transfering data to index...')

这个文件的作用是将文本文件中的描述、代码片段和相似代码片段转换为索引,并将这些索引保存到pickle文件中。
至此,数据预处理部分已经完成。

四、优化项目

我们的研究希望能够优化代码搜索。根据原论文文档,我们尝试通过用大模型处理train_desc.txt来生成新的相似描述sim_desc.txt,然后通过train_tokens.txtsim_desc.txt制作出模型能够识别的pkl文件来训练模型,而不是仅仅依靠字词的token去制作sim_desc文本从而优化代码搜索。
因此,我们利用ollama api实现每一行的描述的相似描述生成,并保存到相似描述文本中。

  • 大模型的使用可以参考我的这篇文章:ollama api的使用,完成上述步骤后再进行以下步骤。
    注意,运行以下代码时必须启动open-webui服务
  • 建议使用如下代码前安装以下环境:(复制内容到requirements.txt文件中,然后使用pip install -r requirements.txt安装)

    1
    2
    3
    4
    5
    6
    requests==2.28.1
    tqdm==4.64.0
    concurrent.futures==3.9.2
    logging==0.5.1.2
    os==0.1.0
    signal==0.2.0

    这是实现基本功能的代码

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    import requests
    import json
    from tqdm import tqdm

    def askLocalQwen2Model(prompt):
    url = "http://localhost:11434/api/generate"

    # 创建要发送的JSON对象
    json_input = {
    "model": "qwen2.5-coder:latest",
    "prompt": prompt,
    "stream": False
    }

    try:
    # 发送POST请求
    response = requests.post(url, headers={"Content-Type": "application/json"}, data=json.dumps(json_input))

    # 检查响应状态码
    if response.status_code == 200:
    # 解析JSON响应并提取response字段
    json_response = response.json()
    return json_response.get("response", "")
    else:
    print(f"Failed to get response from Qwen2 model. Status code: {response.status_code}")
    return ""
    except requests.RequestException as e:
    print(f"An error occurred: {e}")
    return ""

    # 读取输入文件
    input_file_path = "train_desc.txt"
    output_file_path = "final_desc.txt"
    train_prompt="Please generate only the similar descriptive text without any extra content. The output should be approximately the same length as the input and in English.\n"
    with open(input_file_path, 'r', encoding='utf-8') as input_file:
    lines = input_file.readlines()

    # 处理每一行并写入输出文件
    with tqdm(total=len(lines), desc="Processing", unit="line") as pbar:
    with open(output_file_path, 'w', encoding='utf-8') as output_file:
    for line in lines:
    # 去除行尾的换行符
    prompt = train_prompt + line.strip()

    # 使用模型处理生成相似文本
    result = askLocalQwen2Model(prompt)

    # 将结果写入输出文件
    output_file.write(result + "\n")

    # 更新进度条
    pbar.update(1)

    print("处理完成,结果已保存到final.txt")

    这是目前的代码,可能还会修改。提示词可以修改train_prompt,来控制生成的文本能够符合输出要求。

    使用完基本代码发现跑的非常慢,于是添加了多线程,代码如下:
    这是添加了多线程的代码
    修改完上述部分以后,发现处理完本项目的train_desc.txt文件竟然要300多个小时,于是想用多进程来加速处理,于是修改了代码如下:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    import requests
    import json
    from tqdm import tqdm
    from concurrent.futures import ThreadPoolExecutor, as_completed
    import logging
    import os
    import signal
    # 读取输入文件
    input_file_path = "train_desc.txt" ##输入文件
    output_file_path = "final_desc.txt" ##输出文件
    model_name = "qwen2.5-coder:latest" ## 模型名称
    num_workers = 1000 ## 线程数(目前1000最佳)
    ## 模型提示词(要求)
    train_prompt = "Please generate only the similar descriptive text without any extra content. The output should be approximately the same length as the input and in English.\n"
    # temp_dir = "temp_files"
    # os.makedirs(temp_dir, exist_ok=True)


    # 配置日志记录
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

    def askLocalQwen2Model(prompt):
    url = "http://localhost:11434/api/generate"

    # 创建要发送的JSON对象
    json_input = {
    "model": model_name,
    "prompt": prompt,
    "stream": False
    }

    try:
    # 发送POST请求
    response = requests.post(url, headers={"Content-Type": "application/json"}, data=json.dumps(json_input))

    # 检查响应状态码
    if response.status_code == 200: #成功
    # 解析JSON响应并提取response字段
    json_response = response.json()
    return json_response.get("response", "")
    else:
    logging.error(f"Failed to get response from model. Status code: {response.status_code}")
    return ""
    except requests.RequestException as e:
    logging.error(f"An error occurred: {e}")
    return ""

    with open(input_file_path, 'r', encoding='utf-8') as input_file:
    lines = input_file.readlines()
    results = []
    def save_results_and_exit(signum, frame):
    with open(output_file_path, 'w', encoding='utf-8') as output_file:
    for i in range(len(lines)):
    result = next((res for index, res in results if index == i), None)
    if result is not None:
    output_file.write(result + "\n")
    logging.info("Results saved to file. Exiting... 多线程无法返回命令提示符,请自行关闭窗口")
    exit(0)
    # 捕获中断信号,防止Ctrl+C中断程序
    signal.signal(signal.SIGINT, save_results_and_exit)
    with tqdm(total=len(lines), desc="Processing", unit="line") as pbar:
    #并行处理数量
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
    futures = {executor.submit(askLocalQwen2Model, train_prompt + line.strip()): i for i, line in enumerate(lines)}

    for future in as_completed(futures):
    i = futures[future]
    try:
    result = future.result()
    results.append((i, result))
    except Exception as e:
    logging.error(f"Error processing line {i}: {e}")
    results.append((i, ""))

    pbar.update(1)
    # 将结果写入输出文件
    with open(output_file_path, 'w', encoding='utf-8') as output_file:
    for i in range(len(lines)):
    result = next((res for index, res in results if index == i), None)
    output_file.write(result + "\n")
    print("处理完成,结果已保存到final.txt")

    但是上述代码有一些问题,例如线程过多会报错,因此优化后的代码如下:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    import requests
    import json
    from tqdm import tqdm
    from concurrent.futures import ThreadPoolExecutor, as_completed
    import logging
    import os
    import signal
    import time

    # 读取输入文件
    input_file_path = "train_desc.txt" ##输入文件
    output_file_path = "final_desc.txt" ##输出文件
    model_name = "qwen2.5-coder:latest" ## 模型名称
    num_workers = 10 ## 线程数(目前10最佳)
    ## 模型提示词(要求)
    train_prompt = "Please generate only the similar descriptive text without any extra content. The output should be approximately the same length as the input and in English.\n"
    # 配置日志记录
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    def askLocalQwen2Model(prompt, max_retries=3):
    url = "http://localhost:11434/api/generate"

    # 创建要发送的JSON对象
    json_input = {
    "model": model_name,
    "prompt": prompt,
    "stream": False
    }

    for attempt in range(max_retries):
    try:
    # 发送POST请求
    response = requests.post(url, headers={"Content-Type": "application/json"}, data=json.dumps(json_input))

    # 检查响应状态码
    if response.status_code == 200: #成功
    # 解析JSON响应并提取response字段
    json_response = response.json()
    return json_response.get("response", "")
    else:
    logging.warning(f"Attempt {attempt + 1} failed. Status code: {response.status_code}")
    except requests.RequestException as e:
    logging.warning(f"Attempt {attempt + 1} failed: {e}")

    # 如果所有尝试都失败,返回空字符串
    return ""
    def save_results_and_exit(signum, frame):
    with open(output_file_path, 'w', encoding='utf-8') as output_file:
    for i in range(len(lines)):
    if i in processed_indices:
    result = results[processed_indices.index(i)]
    output_file.write(result + "\n")
    logging.info("Results saved to file. Exiting...")
    print("程序已保存结果并退出。请手动关闭窗口或终端。")
    time.sleep(2) # 等待一段时间
    exit(0)
    # 捕获中断信号,防止Ctrl+C中断程序
    signal.signal(signal.SIGINT, save_results_and_exit)
    with open(input_file_path, 'r', encoding='utf-8') as input_file:
    lines = input_file.readlines()
    results = []
    processed_indices = []
    with tqdm(total=len(lines), desc="Processing", unit="line") as pbar:
    # 并行处理数量
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
    futures = {executor.submit(askLocalQwen2Model, train_prompt + line.strip()): i for i, line in enumerate(lines)}

    for future in as_completed(futures):
    i = futures[future]
    try:
    result = future.result()
    results.append(result)
    processed_indices.append(i)
    except Exception as e:
    logging.error(f"Error processing line {i}: {e}")

    pbar.update(1)
    # 将结果写入输出文件
    with open(output_file_path, 'w', encoding='utf-8') as output_file:
    for i in range(len(lines)):
    if i in processed_indices:
    result = results[processed_indices.index(i)]
    output_file.write(result + "\n")
    print("处理完成,结果已保存到final_desc.txt")

  • 生成文本以后,再通过文本预处理以后生成pkl,我们还可以使用以下代码查看pkl文件的内容

    1
    2
    3
    4
    import pickle
    desc_dict = pickle.load(open('vocab.desc.pkl', 'rb'))
    # desc_dict = pickle.load(open('想查看的pkl文件.pkl', 'rb'))
    print(desc_dict)