Notes on Using SemEnr for Code Search
I previously downloaded SemEnr and installed the dependency libraries listed in the README, but ran into errors everywhere. In the end my good friend licyk gave me an excellent solution. Thanks to him!
I have since been using and modifying it; this is my project repository: https://github.com/HG-dev17/SemEnr
1. Installing Anaconda and Visual Studio
1.1 Download
Download link: https://www.anaconda.com/products/distribution
1.2 Installation
Double-click the installer and keep clicking Next. (If you don't want to install it on the C drive, just pick another directory.) Be sure to check the option to add it to PATH, otherwise it cannot be run from the terminal.
1.3 Installing Visual Studio
Download link: https://visualstudio.microsoft.com/zh-hans/
2. Installing the Project
Some of these steps may have become outdated, so this method may have a few issues. Adjust as needed based on any error messages you get.
The following operations are done in PowerShell. Right-click an empty area of the desktop and choose "Open in Terminal" to open a PowerShell terminal.
- Install MicroMamba on the D drive; first create an installation directory for it.
New-Item -ItemType Directory -Path "D:/micromamba"
- Download MicroMamba. I also forked a copy, so either of the following commands works.
Invoke-WebRequest -Uri https://gitee.com/licyk/README-collection/releases/download/archive/micromamba.exe -OutFile "D:/micromamba/micromamba.exe"
Invoke-WebRequest -Uri https://gitee.com/HG-dev17/download/releases/download/archive/micromamba.exe -OutFile "D:/micromamba/micromamba.exe"
- Initialize the MicroMamba configuration.
& "D:/micromamba/micromamba.exe" shell hook -s powershell | Out-String | Invoke-Expression
& "D:/micromamba/micromamba.exe" shell init -s powershell -r "D:/micromamba"
. "$Env:USERPROFILE/Documents/WindowsPowerShell/profile.ps1"
Note: if you run into the following error:
"Cannot load file C:\Users\<username>\Documents\WindowsPowerShell\profile.ps1 because running scripts is disabled on this system. For more information, see ..."
the fix is to run the following command to allow scripts to be executed:
Set-ExecutionPolicy RemoteSigned -Scope CurrentUser
- Download the Conda configuration file for MicroMamba. I forked a copy as well, so either command works.
Invoke-WebRequest -Uri https://gitee.com/licyk/README-collection/releases/download/archive/conda_config.yaml -OutFile "$Env:USERPROFILE/.condarc"
Invoke-WebRequest -Uri https://gitee.com/HG-dev17/download/releases/download/archive/conda_config.yaml -OutFile "$Env:USERPROFILE/.condarc"
- Create a Conda environment for SemEnr.
micromamba create --name semenr git -y
- Activate the SemEnr Conda environment.
micromamba activate semenr
- The SemEnr project will be downloaded to the D drive in a moment, so switch to the D drive first.
Set-Location D:
- Because the GitHub address https://github.com/licyk/SemEnr is hard to reach without a proxy, use a GitHub mirror to download the SemEnr project. I also forked it at https://github.com/HG-dev17/SemEnr , so you can clone from that address directly instead.
git clone https://ghproxy.net/https://github.com/licyk/SemEnr
git clone https://github.com/HG-dev17/SemEnr
- Enter the SemEnr project folder.
Set-Location SemEnr
- Use MicroMamba to install SemEnr's dependencies.
micromamba install -f environment.yml -y
- Pip will be used to install the remaining project dependencies; to keep downloads fast, set a Pip mirror first.
pip config set global.index-url "https://mirrors.cloud.tencent.com/pypi/simple"
- Use Pip to install the remaining dependencies. The SemEnr runtime environment is now set up; the next time you need to enter it, just run the activation command below.
pip install -r requirements.txt
micromamba activate semenr
- Place the dataset in the expected directory:
/data/github
- Finally, cd into the run directory, and you can start the program with:
python .\main.py
3. Using the Project
While using the project, we found that the loss value was not recorded in the results file results/training_results.txt, so we need to modify main.py to write the loss into that file. Remember to remove the plus signs and keep the indentation (a small standalone illustration of this bookkeeping follows the diff).
if hist.history['val_loss'][0] < val_loss['loss']:
    val_loss = {'loss': hist.history['val_loss'][0], 'epoch': i}
+   # write val_loss to the results file
+   f1.write('Best Validation Loss: Epoch={}, Loss={}\n'.format(val_loss['epoch'], val_loss['loss']))
+   f1.flush()
print('Best: Loss = {}, Epoch = {}'.format(val_loss['loss'], val_loss['epoch']))
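To make the intent of the added lines clearer, here is a tiny, self-contained illustration of the same best-validation-loss bookkeeping. It uses a made-up list of loss values instead of a real Keras History object, so it is only a sketch of the logic, not a copy of the actual main.py:
import os

# Pretend these values came from hist.history['val_loss'][0] over four epochs.
val_losses = [0.9, 0.7, 0.75, 0.6]
val_loss = {'loss': float('inf'), 'epoch': 0}

os.makedirs('results', exist_ok=True)
with open('results/training_results.txt', 'a', encoding='utf-8') as f1:
    for i, epoch_loss in enumerate(val_losses):
        if epoch_loss < val_loss['loss']:
            val_loss = {'loss': epoch_loss, 'epoch': i}
            # write the current best validation loss to the results file
            f1.write('Best Validation Loss: Epoch={}, Loss={}\n'.format(val_loss['epoch'], val_loss['loss']))
            f1.flush()
        print('Best: Loss = {}, Epoch = {}'.format(val_loss['loss'], val_loss['epoch']))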
3.1 Analyzing the Project
Data preprocessing
First, let's look at the CreateCorpus.py file:
if __name__ == '__main__':
    sourceFile = open("DeepCom_JAVA/train_tokens.txt", "r", encoding="utf-8")
    corpusFile = open("DeepCom_JAVA/corpus_tokens.txt", "w", encoding="utf-8")
    for num, line in enumerate(sourceFile):
        print(num)
        words = line.split()
        for word in words:
            corpusFile.write(word + '\n')
    sourceFile.close()
    corpusFile.close()
This script splits every line of DeepCom_JAVA/train_tokens.txt on whitespace and writes each resulting word to DeepCom_JAVA/corpus_tokens.txt, one word per line.
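As a quick illustration of what that split does (the input line below is made up, not taken from the dataset):
# A made-up example line from train_tokens.txt.
line = "public static int max ( int a , int b )"
print(line.split())
# -> ['public', 'static', 'int', 'max', '(', 'int', 'a', ',', 'int', 'b', ')']
# Each of these tokens ends up on its own line in corpus_tokens.txt.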
Next is the CreateVocab.py file:
from collections import Counter

vocabWord = open("DeepCom_JAVA/corpus_tokens.txt", "r", encoding="utf-8")
processStaFile = open("DeepCom_JAVA/vocab.tokens.txt", "w", encoding="utf-8")
staTreeList = []
while 1:
    word = vocabWord.readline().splitlines()
    if not word:
        break
    staTreeList.append(word[0])
staTreeDic = Counter(staTreeList)
for k, v in staTreeDic.items():
    if v >= 11:
        processStaFile.write(k + '\n')
This script counts how often each word appears in DeepCom_JAVA/corpus_tokens.txt and writes the words that occur at least 11 times to DeepCom_JAVA/vocab.tokens.txt, one word per line.
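For reference, a slightly more idiomatic way to express the same counting logic (just a sketch, functionally equivalent apart from skipping empty lines; it is not part of the project):
from collections import Counter

# Count every token in the corpus file.
with open("DeepCom_JAVA/corpus_tokens.txt", "r", encoding="utf-8") as f:
    counts = Counter(line.strip() for line in f if line.strip())

# Keep only tokens that occur at least 11 times, one per line.
with open("DeepCom_JAVA/vocab.tokens.txt", "w", encoding="utf-8") as out:
    for word, count in counts.items():
        if count >= 11:
            out.write(word + '\n')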
Next is the Vocab2pkl.py file:
import codecs
import pickle

sbt_list = []
f = codecs.open('DeepCom_JAVA/vocab.tokens.txt', encoding='utf8', errors='replace').readlines()
for line in f:
    line = line.strip()
    sbt_list.append(line)
    print(line)
sbt_dictionary = {value: index + 1 for index, value in enumerate(sbt_list)}
pickle.dump(sbt_dictionary, open("DeepCom_JAVA/vocab.tokens.pkl", 'wb'))
This script turns the words in DeepCom_JAVA/vocab.tokens.txt into a dictionary mapping each word to an index (starting from 1) and saves that dictionary to DeepCom_JAVA/vocab.tokens.pkl.
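A quick way to sanity-check the resulting vocabulary pickle (just a verification snippet, not part of the project):
import pickle

with open("DeepCom_JAVA/vocab.tokens.pkl", "rb") as f:
    vocab = pickle.load(f)

# Print the first five word -> index pairs.
for word, index in list(vocab.items())[:5]:
    print(word, index)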
Finally, the txt2pkl.py file:
import pickle
import codecs

def desc2index(desc_dict):
    desc = []
    desc_index = []
    fileDesc = codecs.open('DeepCom_JAVA/test_desc.txt', encoding='utf8',
                           errors='replace').readlines()
    for line in fileDesc:
        line = line.split()
        desc.append(line)
    for item in desc:
        new_item = []
        for word in item:
            try:
                new_item.append(desc_dict[word])
            except:
                new_item.append(3)
        desc_index.append(new_item)
    return desc_index

def tokens2index(tokens_dict):
    tokens = []
    tokens_index = []
    fileTokens = codecs.open('DeepCom_JAVA/test_tokens.txt', encoding='utf8',
                             errors='replace').readlines()
    for line in fileTokens:
        line = line.split()
        tokens.append(line)
    for item in tokens:
        new_item = []
        for word in item:
            try:
                new_item.append(tokens_dict[word])
            except:
                new_item.append(3)
        tokens_index.append(new_item)
    return tokens_index

def simWords2index(desc_dict):
    tokens = []
    tokens_index = []
    fileTokens = codecs.open('DeepCom_JAVA/test_IR_code_desc_sw.txt', encoding='utf8',
                             errors='replace').readlines()
    for line in fileTokens:
        line = line.split()
        tokens.append(line)
    for item in tokens:
        new_item = []
        for word in item:
            try:
                new_item.append(desc_dict[word])
            except:
                new_item.append(3)
        tokens_index.append(new_item)
    return tokens_index

if __name__ == '__main__':
    desc_dict = pickle.load(open('DeepCom_JAVA/vocab.desc.pkl', 'rb'))      ## descriptions
    tokens_dict = pickle.load(open('DeepCom_JAVA/vocab.tokens.pkl', 'rb'))  ## code
    desc = desc2index(desc_dict)          ## descriptions
    tokens = tokens2index(tokens_dict)    ## code
    sim_desc = simWords2index(desc_dict)  ## similar code descriptions
    pickle.dump(desc, open('DeepCom_JAVA/test.desc.pkl', 'wb'))             ## descriptions
    pickle.dump(tokens, open('DeepCom_JAVA/test.tokens.pkl', 'wb'))         ## code
    pickle.dump(sim_desc, open('DeepCom_JAVA/test_IR_code_desc.pkl', 'wb')) ## similar code descriptions
    print('finish transfering data to index...')
This script converts the descriptions, code token sequences, and similar-code descriptions in the text files into index sequences and saves them as pickle files. Any word not found in the vocabulary is mapped to index 3, which serves as the out-of-vocabulary index.
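A tiny illustration of that out-of-vocabulary fallback (the dictionary fragment below is made up, not the real vocab.desc.pkl):
# Made-up fragment of a word -> index vocabulary.
desc_dict = {'return': 5, 'the': 2, 'maximum': 41}
line = "return the maximum frobnicated value"
# Words missing from the vocabulary fall back to index 3, as in txt2pkl.py.
indexed = [desc_dict.get(word, 3) for word in line.split()]
print(indexed)  # -> [5, 2, 41, 3, 3]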
At this point, the data preprocessing part is complete.
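To recap, the preprocessing scripts are run in the order described above. A minimal sketch of running them in sequence (this assumes the scripts sit in the project root and that the description vocabulary vocab.desc.pkl has been built the same way; adjust paths to your own layout):
import subprocess

# Run the preprocessing scripts in the order they were introduced above.
for script in ["CreateCorpus.py", "CreateVocab.py", "Vocab2pkl.py", "txt2pkl.py"]:
    subprocess.run(["python", script], check=True)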
4. Optimizing the Project
Our research aims to improve the code search. Following the original paper, we try to process train_desc.txt with a large language model to generate new similar descriptions (sim_desc.txt), and then build the pkl files the model can read from train_tokens.txt and sim_desc.txt to train the model, instead of building the sim_desc text from word-level tokens alone. This is how we hope to improve code search.
To do this, we use the ollama api to generate a similar description for every line of descriptions and save the results into the similar-description text file.
- For how to use the large model, see my article on using the ollama api; complete those steps before continuing with the steps below.
Note: the open-webui service must be running when you execute the code below (the code talks to the local Ollama API at http://localhost:11434).
Before running the code, it is recommended to install the following dependencies (copy the contents into a requirements.txt file and install with pip install -r requirements.txt). Note that concurrent.futures, logging, os and signal are part of the Python standard library and do not need to be installed separately:
requests==2.28.1
tqdm==4.64.0
This is the code implementing the basic functionality:
import requests
import json
from tqdm import tqdm

def askLocalQwen2Model(prompt):
    url = "http://localhost:11434/api/generate"
    # Build the JSON payload to send
    json_input = {
        "model": "qwen2.5-coder:latest",
        "prompt": prompt,
        "stream": False
    }
    try:
        # Send the POST request
        response = requests.post(url, headers={"Content-Type": "application/json"}, data=json.dumps(json_input))
        # Check the response status code
        if response.status_code == 200:
            # Parse the JSON response and extract the "response" field
            json_response = response.json()
            return json_response.get("response", "")
        else:
            print(f"Failed to get response from Qwen2 model. Status code: {response.status_code}")
            return ""
    except requests.RequestException as e:
        print(f"An error occurred: {e}")
        return ""

# Input and output files
input_file_path = "train_desc.txt"
output_file_path = "final_desc.txt"
train_prompt = "Please generate only the similar descriptive text without any extra content. The output should be approximately the same length as the input and in English.\n"

# Read the input file
with open(input_file_path, 'r', encoding='utf-8') as input_file:
    lines = input_file.readlines()

# Process each line and write the result to the output file
with tqdm(total=len(lines), desc="Processing", unit="line") as pbar:
    with open(output_file_path, 'w', encoding='utf-8') as output_file:
        for line in lines:
            # Strip the trailing newline and prepend the prompt
            prompt = train_prompt + line.strip()
            # Ask the model to generate a similar description
            result = askLocalQwen2Model(prompt)
            # Write the result to the output file
            output_file.write(result + "\n")
            # Update the progress bar
            pbar.update(1)

print("Done. Results have been saved to final_desc.txt")
This is the current version of the code and it may still change. You can edit train_prompt to adjust the prompt so the generated text meets the output requirements. After running this basic code, it turned out to be extremely slow: processing this project's train_desc.txt would have taken more than 300 hours. So I added multithreading to speed things up; the modified code is as follows:
import requests
import json
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import logging
import os
import signal

# Input/output files and settings
input_file_path = "train_desc.txt"   ## input file
output_file_path = "final_desc.txt"  ## output file
model_name = "qwen2.5-coder:latest"  ## model name
num_workers = 1000                   ## number of threads (1000 seemed best at the time)
## Model prompt (the requirements for the generated text)
train_prompt = "Please generate only the similar descriptive text without any extra content. The output should be approximately the same length as the input and in English.\n"
# temp_dir = "temp_files"
# os.makedirs(temp_dir, exist_ok=True)

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def askLocalQwen2Model(prompt):
    url = "http://localhost:11434/api/generate"
    # Build the JSON payload to send
    json_input = {
        "model": model_name,
        "prompt": prompt,
        "stream": False
    }
    try:
        # Send the POST request
        response = requests.post(url, headers={"Content-Type": "application/json"}, data=json.dumps(json_input))
        # Check the response status code
        if response.status_code == 200:  # success
            # Parse the JSON response and extract the "response" field
            json_response = response.json()
            return json_response.get("response", "")
        else:
            logging.error(f"Failed to get response from model. Status code: {response.status_code}")
            return ""
    except requests.RequestException as e:
        logging.error(f"An error occurred: {e}")
        return ""

with open(input_file_path, 'r', encoding='utf-8') as input_file:
    lines = input_file.readlines()

results = []

def save_results_and_exit(signum, frame):
    with open(output_file_path, 'w', encoding='utf-8') as output_file:
        for i in range(len(lines)):
            result = next((res for index, res in results if index == i), None)
            if result is not None:
                output_file.write(result + "\n")
    logging.info("Results saved to file. Exiting... (the prompt may not return when multithreading; close the window yourself)")
    exit(0)

# Catch Ctrl+C so an interrupt saves the partial results before exiting
signal.signal(signal.SIGINT, save_results_and_exit)

with tqdm(total=len(lines), desc="Processing", unit="line") as pbar:
    # Number of requests processed in parallel
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = {executor.submit(askLocalQwen2Model, train_prompt + line.strip()): i for i, line in enumerate(lines)}
        for future in as_completed(futures):
            i = futures[future]
            try:
                result = future.result()
                results.append((i, result))
            except Exception as e:
                logging.error(f"Error processing line {i}: {e}")
                results.append((i, ""))
            pbar.update(1)

# Write the results to the output file in the original line order
with open(output_file_path, 'w', encoding='utf-8') as output_file:
    for i in range(len(lines)):
        result = next((res for index, res in results if index == i), None)
        output_file.write(result + "\n")

print("Done. Results have been saved to final_desc.txt")
However, the code above has some problems; for example, using too many threads causes errors. The optimized code is as follows:
import requests
import json
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import logging
import os
import signal
import time

# Input/output files and settings
input_file_path = "train_desc.txt"   ## input file
output_file_path = "final_desc.txt"  ## output file
model_name = "qwen2.5-coder:latest"  ## model name
num_workers = 10                     ## number of threads (10 works best so far)
## Model prompt (the requirements for the generated text)
train_prompt = "Please generate only the similar descriptive text without any extra content. The output should be approximately the same length as the input and in English.\n"

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def askLocalQwen2Model(prompt, max_retries=3):
    url = "http://localhost:11434/api/generate"
    # Build the JSON payload to send
    json_input = {
        "model": model_name,
        "prompt": prompt,
        "stream": False
    }
    for attempt in range(max_retries):
        try:
            # Send the POST request
            response = requests.post(url, headers={"Content-Type": "application/json"}, data=json.dumps(json_input))
            # Check the response status code
            if response.status_code == 200:  # success
                # Parse the JSON response and extract the "response" field
                json_response = response.json()
                return json_response.get("response", "")
            else:
                logging.warning(f"Attempt {attempt + 1} failed. Status code: {response.status_code}")
        except requests.RequestException as e:
            logging.warning(f"Attempt {attempt + 1} failed: {e}")
    # If every attempt fails, return an empty string
    return ""

def save_results_and_exit(signum, frame):
    with open(output_file_path, 'w', encoding='utf-8') as output_file:
        for i in range(len(lines)):
            if i in processed_indices:
                result = results[processed_indices.index(i)]
                output_file.write(result + "\n")
    logging.info("Results saved to file. Exiting...")
    print("The program has saved the results and exited. Please close the window or terminal manually.")
    time.sleep(2)  # wait a moment before exiting
    exit(0)

# Catch Ctrl+C so an interrupt saves the partial results before exiting
signal.signal(signal.SIGINT, save_results_and_exit)

with open(input_file_path, 'r', encoding='utf-8') as input_file:
    lines = input_file.readlines()

results = []
processed_indices = []

with tqdm(total=len(lines), desc="Processing", unit="line") as pbar:
    # Number of requests processed in parallel
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = {executor.submit(askLocalQwen2Model, train_prompt + line.strip()): i for i, line in enumerate(lines)}
        for future in as_completed(futures):
            i = futures[future]
            try:
                result = future.result()
                results.append(result)
                processed_indices.append(i)
            except Exception as e:
                logging.error(f"Error processing line {i}: {e}")
            pbar.update(1)

# Write the results to the output file in the original line order
with open(output_file_path, 'w', encoding='utf-8') as output_file:
    for i in range(len(lines)):
        if i in processed_indices:
            result = results[processed_indices.index(i)]
            output_file.write(result + "\n")

print("Done. Results have been saved to final_desc.txt")
After the text has been generated and run through the same preprocessing to produce pkl files, you can also use the following code to inspect the contents of a pkl file:
import pickle
desc_dict = pickle.load(open('vocab.desc.pkl', 'rb'))
# desc_dict = pickle.load(open('whichever_pkl_file_you_want_to_inspect.pkl', 'rb'))
print(desc_dict)
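To turn the generated final_desc.txt into a pkl the model can use for training, the same indexing logic as txt2pkl.py applies. Here is a minimal sketch under the assumption that the generated descriptions should be indexed with vocab.desc.pkl; the output file name train_IR_code_desc.pkl and the exact paths are assumptions, so adjust them to your own layout:
import pickle

# Load the description vocabulary built earlier (word -> index).
desc_dict = pickle.load(open('DeepCom_JAVA/vocab.desc.pkl', 'rb'))

sim_desc_index = []
with open('final_desc.txt', encoding='utf8', errors='replace') as f:
    for line in f:
        # Unknown words fall back to index 3, the same OOV index used in txt2pkl.py.
        sim_desc_index.append([desc_dict.get(word, 3) for word in line.split()])

# Save the indexed similar descriptions (output file name is an assumption).
pickle.dump(sim_desc_index, open('DeepCom_JAVA/train_IR_code_desc.pkl', 'wb'))
print('finish transfering final_desc.txt to index...')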