llama3 dp torchrun

This commit is contained in:
Tom Young 2024-06-13 05:55:47 +00:00
parent a8a34d56b9
commit a397173104

View file

@ -1,127 +1,223 @@
# NOTE(review): the pasted file interleaved the pre- and post-commit import
# lines (duplicate argparse/csv/os, both `import tqdm` and `from tqdm import
# tqdm`, both `.utils` imports). This is the de-duplicated post-commit set:
# the code below uses `tqdm(dataloader)` (the callable) and only `read_file`
# from the local utils module.
from datetime import timedelta
import argparse
import base64
import csv
import os
import warnings
from io import BytesIO

import pandas as pd
import requests
import torch
import torch.distributed as dist
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

from .utils import read_file

warnings.filterwarnings("ignore")
# HF tokenizers' internal threading does not mix well with DataLoader workers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class CSVTextDataset(Dataset):
    """Dataset over a CSV file that must contain a 'text' column.

    Supports manual data-parallel sharding: after set_rank_and_world_size()
    the dataset only exposes this rank's contiguous slice of the rows.
    """

    def __init__(self, csv_path):
        self.df = pd.read_csv(csv_path)
        # The downstream prompt pipeline reads row['text'], so fail fast here.
        assert 'text' in self.df.columns, "text column not found in the csv file"

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # DataLoader uses IndexError to detect the end of a map-style dataset.
        if idx < 0 or idx >= len(self.df):
            raise IndexError
        return self.df.iloc[idx]

    def set_rank_and_world_size(self, rank, world_size):
        """Restrict the dataset to this rank's contiguous shard of the rows.

        The last rank absorbs the remainder when len(self) is not divisible
        by world_size.
        """
        self.rank = rank
        self.world_size = world_size
        self.data_per_gpu = len(self) // world_size
        self.start_index = rank * self.data_per_gpu
        self.end_index = (rank + 1) * self.data_per_gpu if rank != world_size - 1 else len(self)
        self.df = self.df.iloc[self.start_index:self.end_index]

    def write_to_csv(self, output_file, data, new_key):
        """Write this rank's shard to `output_file` with `data` appended as a
        new `new_key` column; `data[i]` is paired with the shard's i-th row.

        Fixes vs. the original: the file handle is now closed (csv.writer has
        no close() method, so `writer.close()` raised AttributeError and the
        handle leaked), and the header is built with list concatenation
        instead of `Index + list` (which pandas treats element-wise).
        """
        with open(output_file, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(list(self.df.columns) + [new_key])
            for index, row in self.df.iterrows():
                # After slicing, the RangeIndex labels still span
                # [start_index, end_index); anything else is skipped.
                if index < self.start_index or index >= self.end_index:
                    continue
                writer.writerow([*row, data[index - self.start_index]])
def pad_left(sequences, padding_value=0):
    """Left-pad a list of 1-D tensors to a common length and stack them.

    Shorter sequences get `padding_value` prepended so that all of them end
    at the same position (the right edge), then the results are stacked into
    a single (batch, max_len) tensor.
    """
    target_len = max(seq.size(0) for seq in sequences)
    padded = [
        # new_full keeps the sequence's dtype and device for the pad chunk.
        torch.cat([seq.new_full((target_len - seq.size(0),), padding_value), seq], dim=0)
        for seq in sequences
    ]
    return torch.stack(padded)
def main(args):
    """Distributed (one process per GPU, launched via torchrun) keyword
    extraction over a CSV of captions using a Llama-3 chat model.

    Each rank processes a disjoint contiguous shard of the input CSV and
    writes its shard's rows, plus a new `args.key` column holding the
    extracted keywords, to its own ``<input>_rank<r>_llama3.csv`` file.
    """
    # ======================================================
    # 1. init distributed environment
    # ======================================================
    # Long timeout: ranks can drift far apart while generating.
    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    torch.cuda.set_device(device)

    # ======================================================
    # 2. prep rank-wise dataloader over this rank's shard
    # ======================================================
    dataframe = read_file(args.input)
    print("read data from {}".format(args.input))
    dataset = CSVTextDataset(args.input)
    dataset.set_rank_and_world_size(rank, dist.get_world_size())

    output_file = os.path.splitext(args.input)[0] + f"_rank{rank}" + "_llama3.csv"
    print("the processed data saved on this rank will be saved to {}".format(output_file))

    def collate_fn(batch):
        # Rows are pandas Series, not tensors: keep them as a plain list.
        return batch

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=2,
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        shuffle=False,  # keep output rows aligned with the input shard order
    )

    # ======================================================
    # 3. load one model/tokenizer replica on this rank's GPU
    # ======================================================
    print("Using model with the id {}".format(args.model_id))
    model_id = args.model_id
    # Decoder-only models must be left-padded for batched generation.
    tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map=device,  # pin the whole model to this rank's GPU
    )
    dist.barrier()
    print("======== Process data using LLAMA3 ========")

    def extract_batch(texts, prompt):
        """Run one batched chat completion; return one response string per text."""
        input_ids_list = [
            tokenizer.apply_chat_template(
                [
                    {"role": "system", "content": prompt},
                    {"role": "user", "content": text},
                ],
                add_generation_prompt=True,
                return_tensors="pt",
            ).to(model.device)[0]
            for text in texts
        ]
        attention_mask_list = [
            torch.ones(input_ids.shape, dtype=torch.long, device=model.device)
            for input_ids in input_ids_list
        ]
        # Left-pad (matches padding_side="left" above): right padding would
        # insert pad tokens between each prompt and its generated text, and
        # the per-row slice below would then start inside the padding.
        input_ids_batch = pad_left(input_ids_list, padding_value=tokenizer.eos_token_id)
        attention_mask_batch = pad_left(attention_mask_list, padding_value=0)
        # Llama-3 chat ends turns with <|eot_id|>, not only the eos token.
        terminators = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]
        outputs = model.generate(
            input_ids_batch,
            max_new_tokens=512,
            attention_mask=attention_mask_batch,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=terminators,
        )
        # With left padding every row's prompt occupies the first
        # input_ids_batch.shape[1] positions, so completions start there.
        prompt_len = input_ids_batch.shape[1]
        return [
            tokenizer.decode(outputs[i][prompt_len:], skip_special_tokens=True)
            for i in range(len(texts))
        ]

    print("Processing starting...")
    if args.prompt == "":
        prompt = ("You are a AI assistant to extract objects from user's text. "
                  "For example: user: 'In this video a dog is running around. In addition, a person is laughing at the dog.', you produce a list of objects separated by ',' and wrapped by '[' and ']': '[dog, person]' ")
    else:
        prompt = args.prompt
    print("Prompt: {}".format(prompt))

    # ======================================================
    # 4. stream the shard through the model, append results to the CSV
    # ======================================================
    with open(output_file, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(list(dataframe.columns) + [args.key])
        for batch in tqdm(dataloader):
            texts = [batch[i]['text'] for i in range(len(batch))]
            list_keywords = extract_batch(texts, prompt)
            for idx, keywords in enumerate(list_keywords):
                # Keep only the content between the first '[' and ']'; the
                # original wrapped this in try/except, but str.find never
                # raises -- check for the -1 "not found" sentinel instead.
                start = keywords.find("[")
                end = keywords.find("]")
                if start == -1 or end <= start:
                    keywords = "NONE_FOUND"
                else:
                    keywords = keywords[start + 1:end]
                writer.writerow([*batch[idx], keywords])

    dist.barrier()
    # terminate distributed env
    dist.destroy_process_group()
if __name__ == "__main__":
    # NOTE(review): the pasted diff showed both the pre- and post-commit CLI;
    # registering --model-id twice raises argparse.ArgumentError and main()
    # was called twice. This keeps only the post-commit options.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-id", default="meta-llama/Meta-Llama-3-8B-Instruct")
    parser.add_argument("input", type=str, help="Path to the input CSV file")
    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument('--key', default='summary', type=str)
    args = parser.parse_args()
    main(args)