mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-04-11 05:13:31 +02:00
llama3 dp torchrun
This commit is contained in:
parent
a8a34d56b9
commit
a397173104
|
|
@ -1,127 +1,223 @@
|
|||
from datetime import timedelta
|
||||
import argparse
|
||||
import base64
|
||||
import csv
|
||||
import os
|
||||
from io import BytesIO
|
||||
import torch.distributed as dist
|
||||
import argparse
|
||||
import csv
|
||||
import os
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import requests
|
||||
import tqdm
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
import torch
|
||||
from .utils import PROMPTS, VideoTextDataset, read_file
|
||||
from .utils import read_file
|
||||
import warnings
|
||||
import pandas as pd
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
class CSVTextDataset(Dataset):
|
||||
def __init__(self, csv_path):
|
||||
self.df = pd.read_csv(csv_path)
|
||||
# assert text is in the columns
|
||||
assert 'text' in self.df.columns, "text column not found in the csv file"
|
||||
|
||||
def __len__(self):
|
||||
return len(self.df)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if idx < 0 or idx >= len(self.df):
|
||||
raise IndexError
|
||||
return self.df.iloc[idx]
|
||||
|
||||
def set_rank_and_world_size(self, rank, world_size):
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
self.data_per_gpu = len(self) // world_size
|
||||
self.start_index = rank * self.data_per_gpu
|
||||
self.end_index = (rank + 1) * self.data_per_gpu if rank != world_size - 1 else len(self)
|
||||
self.df = self.df.iloc[self.start_index:self.end_index]
|
||||
|
||||
def write_to_csv(self, output_file, data, new_key):
|
||||
''' write the part of the df to a csv file corresponding to the rank and write self.data_list as a new column '''
|
||||
writer = csv.writer(open(output_file, "w"))
|
||||
columns = self.df.columns + [new_key]
|
||||
writer.writerow(columns)
|
||||
for index, row in self.df.iterrows():
|
||||
if index < self.start_index or index >= self.end_index:
|
||||
continue
|
||||
writer.writerow([*row, data[index - self.start_index]])
|
||||
writer.close()
|
||||
|
||||
def pad_left(sequences, padding_value=0):
|
||||
# Determine the maximum length of the sequences
|
||||
max_len = max([s.size(0) for s in sequences])
|
||||
# Create a list to hold the padded sequences
|
||||
padded_sequences = []
|
||||
for sequence in sequences:
|
||||
# Calculate the number of padding elements needed for this sequence
|
||||
num_padding = max_len - sequence.size(0)
|
||||
# Create a tensor of padding values
|
||||
padding = torch.full((num_padding,), padding_value, dtype=sequence.dtype).to(sequence.device)
|
||||
# Concatenate the padding and the sequence to pad on the left
|
||||
padded_sequence = torch.cat([padding, sequence], dim=0)
|
||||
padded_sequences.append(padded_sequence)
|
||||
# Stack the padded sequences into a batch
|
||||
batch = torch.stack(padded_sequences)
|
||||
return batch
|
||||
|
||||
def main(args):
|
||||
# ======================================================
|
||||
# 1. read video list
|
||||
# 1. init environment
|
||||
# ======================================================
|
||||
dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
|
||||
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
|
||||
|
||||
# ======================================================
|
||||
# 2. Prep rank-wise dataloader
|
||||
# ======================================================
|
||||
dataframe = read_file(args.input)
|
||||
print("read data from {}".format(args.input))
|
||||
|
||||
output_file = os.path.splitext(args.input)[0] + "_llama3.csv"
|
||||
f = open(output_file, "w")
|
||||
writer = csv.writer(f)
|
||||
|
||||
columns = dataframe.columns + [args.key]
|
||||
dataset = CSVTextDataset(args.input)
|
||||
dataset.set_rank_and_world_size(
|
||||
dist.get_rank(),
|
||||
dist.get_world_size()
|
||||
)
|
||||
|
||||
output_file = os.path.splitext(args.input)[0] + f"_rank{dist.get_rank()}" + "_llama3.csv"
|
||||
|
||||
writer = csv.writer(open(output_file, "w"))
|
||||
columns = list(dataframe.columns) + [args.key]
|
||||
|
||||
writer.writerow(columns)
|
||||
|
||||
# add a new key named summary, write in csv file
|
||||
print("the processed data saved to {}".format(output_file))
|
||||
print("the processed data saved on this rank will be saved to {}".format(output_file))
|
||||
|
||||
def collate_fn(batch):
|
||||
return batch
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=2,
|
||||
batch_size=args.batch_size,
|
||||
collate_fn=collate_fn,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
# ======================================================
|
||||
# 2. process using llama3 and prompt
|
||||
# ======================================================
|
||||
|
||||
print("Using model with the id {}".format(args.model_id))
|
||||
model_id = args.model_id
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
device_map=dist.get_rank() % torch.cuda.device_count(),
|
||||
)
|
||||
# .to(dist.get_rank() % torch.cuda.device_count())
|
||||
dist.barrier()
|
||||
print("======== Process data using LLAMA3 ========")
|
||||
|
||||
def extract(text, prompt):
|
||||
messages = [
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": text},
|
||||
def extract_batch(texts, prompt):
|
||||
input_ids_list = [
|
||||
tokenizer.apply_chat_template(
|
||||
[
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": text}
|
||||
],
|
||||
add_generation_prompt=True,
|
||||
return_tensors="pt"
|
||||
).to(model.device)[0]
|
||||
for text in texts
|
||||
]
|
||||
|
||||
attention_mask_list = [
|
||||
torch.ones(input_ids.shape, dtype=torch.long, device=model.device)
|
||||
for input_ids in input_ids_list
|
||||
]
|
||||
|
||||
# input_ids_batch = pad_left(
|
||||
# input_ids_list, padding_value=tokenizer.eos_token_id
|
||||
# )
|
||||
|
||||
input_ids_batch = torch.nn.utils.rnn.pad_sequence(
|
||||
input_ids_list, batch_first=True, padding_value=tokenizer.eos_token_id
|
||||
)
|
||||
|
||||
attention_mask_batch = torch.nn.utils.rnn.pad_sequence(
|
||||
attention_mask_list, batch_first=True, padding_value=0
|
||||
)
|
||||
|
||||
# attention_mask_batch = pad_left(
|
||||
# attention_mask_list, padding_value=0
|
||||
# )
|
||||
|
||||
input_ids = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
return_tensors="pt"
|
||||
).to(model.device)
|
||||
|
||||
# The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
|
||||
# Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
|
||||
terminators = [tokenizer.eos_token_id,
|
||||
tokenizer.convert_tokens_to_ids("<|eot_id|>"),
|
||||
]
|
||||
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=model.device)
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids,
|
||||
input_ids_batch,
|
||||
max_new_tokens=512,
|
||||
attention_mask=attention_mask,
|
||||
attention_mask=attention_mask_batch,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
eos_token_id=terminators,
|
||||
do_sample=True,
|
||||
temperature=0.6,
|
||||
top_p=0.9,
|
||||
# do_sample=True,
|
||||
# temperature=0.6,
|
||||
# top_p=0.9,
|
||||
)
|
||||
response = outputs[0][input_ids.shape[-1]:]
|
||||
response = tokenizer.decode(response, skip_special_tokens=True)
|
||||
return response
|
||||
|
||||
responses = []
|
||||
for i in range(len(texts)):
|
||||
response = outputs[i][input_ids_list[i].shape[-1]:]
|
||||
response = tokenizer.decode(response, skip_special_tokens=True)
|
||||
responses.append(response)
|
||||
|
||||
return responses
|
||||
|
||||
print("Processing starting...")
|
||||
if args.prompt == "":
|
||||
prompt = ("You are a AI assistant to extract user's text into keywords. "
|
||||
"For example: user: 'a woman on a news station talking about traffic.', you just need a list of keywords separate by , and covered by '[' and ']': '[woman, news station, traffic]' ")
|
||||
prompt = ("You are a AI assistant to extract objects from user's text. "
|
||||
"For example: user: 'In this video a dog is running around. In addition, a person is laughing at the dog.', you produce a list of objects separated by ',' and wrapped by '[' and ']': '[dog, person]' ")
|
||||
else:
|
||||
prompt = args.prompt
|
||||
|
||||
print("Prompt: {}".format(prompt))
|
||||
key = args.key
|
||||
# prompt to process text, first input prompt to AI assistant, second input is the text to process
|
||||
# batch process, row['text'] to key extraction, for example: a woman on a news station talking about traffic. -> woman, news station, traffic
|
||||
|
||||
length = len(dataframe)
|
||||
for index, row in dataframe.iterrows():
|
||||
row_text = row['text']
|
||||
batch_size = args.batch_size
|
||||
# for i in tqdm(range(0, len(dataframe), batch_size)):
|
||||
for _, batch in enumerate(tqdm(dataloader)):
|
||||
# get the text column from the batch
|
||||
texts = [batch[i]['text'] for i in range(len(batch))]
|
||||
list_keywords = extract_batch(texts, prompt)
|
||||
|
||||
# process text
|
||||
keywords = extract(row_text, prompt)
|
||||
for idx, keywords in enumerate(list_keywords):
|
||||
try:
|
||||
keywords_start = keywords.find("[")
|
||||
keywords_end = keywords.find("]")
|
||||
keywords = keywords[keywords_start+1:keywords_end]
|
||||
except:
|
||||
keywords = "NONE_FOUND"
|
||||
|
||||
# process keywords
|
||||
keywords_start = keywords.find("[")
|
||||
keywords_end = keywords.find("]")
|
||||
keywords = keywords[keywords_start+1:keywords_end]
|
||||
|
||||
if index % 100 == 0:
|
||||
print("{}/{}".format(index, length))
|
||||
print(f"text: {row_text} "
|
||||
f"keywords: {keywords}")
|
||||
|
||||
# add a new key named summary, write in csv file
|
||||
writer.writerow([*row, keywords])
|
||||
|
||||
f.close()
|
||||
row = batch[idx]
|
||||
writer.writerow([*row, keywords])
|
||||
dist.barrier()
|
||||
# terminate distributed env
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model-id", default="/home/data/models/meta-llama/Meta-Llama-3-8B-Instruct")
|
||||
parser.add_argument("--model-id", default="meta-llama/Meta-Llama-3-8B-Instruct")
|
||||
parser.add_argument("input", type=str, help="Path to the input CSV file")
|
||||
parser.add_argument("--prompt", type=str, default="")
|
||||
parser.add_argument("--batch", type=int, default=16)
|
||||
parser.add_argument("--batch_size", type=int, default=64)
|
||||
parser.add_argument('--key', default='summary', type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
main(args)
|
||||
Loading…
Reference in a new issue