Skip to content

Commit

Permalink
chinese llama2
Browse files (browse the repository at this point in the history)
  • Loading branch information
DLLXW committed Aug 12, 2023
1 parent 8849c3e commit 3753dda
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 19 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ python data_process.py
#预训练
python pretrain.py
#SFT
python train_sft.py
python sft.py
```
根据自己算力的情况合理地调节以下参数,控制模型的计算量和参数量
- max_seq_len = 512
Expand Down
10 changes: 1 addition & 9 deletions pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def init_model():
n_layers=n_layers,
n_heads=n_heads,
n_kv_heads=n_heads,
vocab_size=65535,
vocab_size=64793,
multiple_of=multiple_of,
max_seq_len=max_seq_len,
dropout=dropout,
Expand Down Expand Up @@ -246,14 +246,6 @@ def init_model():
)
#
best_val_loss = 1e9
# attempt to derive vocab_size from the dataset
meta_path = os.path.join('./meta.pkl')
meta_vocab_size = None
if os.path.exists(meta_path):
with open(meta_path, 'rb') as f:
meta = pickle.load(f)
meta_vocab_size = meta['vocab_size']
print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
#-----init dataloader------
data_path_list=[
'./data/wiki.bin',
Expand Down
10 changes: 1 addition & 9 deletions sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def init_model():
n_layers=n_layers,
n_heads=n_heads,
n_kv_heads=n_heads,
vocab_size=65535,#64793,
vocab_size=64793,#64793,
multiple_of=multiple_of,
max_seq_len=max_seq_len,
dropout=dropout,
Expand Down Expand Up @@ -252,14 +252,6 @@ def init_model():
)
#
best_val_loss = 1e9
# attempt to derive vocab_size from the dataset
meta_path = os.path.join('./meta.pkl')
meta_vocab_size = None
if os.path.exists(meta_path):
with open(meta_path, 'rb') as f:
meta = pickle.load(f)
meta_vocab_size = meta['vocab_size']
print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
#-----init dataloader------
df_sft=pd.read_csv('./data/sft_data.csv')
input=[]
Expand Down

0 comments on commit 3753dda

Please sign in to comment.