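small update: allow batch_size=None when saving running states, add a --train-column option, tolerate an empty output directory in save_file, and derive resolution from height * width when the column is missing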

This commit is contained in:
Zangwei Zheng 2024-04-15 17:36:20 +08:00
parent 63e86f6dd6
commit 50ad9a1ca5
3 changed files with 18 additions and 4 deletions

View file

@ -65,6 +65,6 @@ log_every = 10
ckpt_every = 500
load = None
batch_size = 17 # only for logging
batch_size = None
lr = 2e-5
grad_clip = 1.0

View file

@ -169,11 +169,12 @@ def save(
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
if lr_scheduler is not None:
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
sampler_start_idx = step * batch_size if batch_size is not None else None
running_states = {
"epoch": epoch,
"step": step,
"global_step": global_step,
"sample_start_index": step * batch_size,
"sample_start_index": sampler_start_idx,
}
if coordinator.is_master():
save_json(running_states, os.path.join(save_dir, "running_states.json"))

View file

@ -32,6 +32,8 @@ def apply(df, func, **kwargs):
return df.progress_apply(func, **kwargs)
TRAIN_COLUMNS = ["path", "text", "num_frames", "fps", "height", "width", "aspect_ratio", "resolution", "text_len"]
# ======================================================
# --info
# ======================================================
@ -359,7 +361,7 @@ def read_file(input_path):
def save_file(data, output_path):
output_dir = os.path.dirname(output_path)
if not os.path.exists(output_dir):
if not os.path.exists(output_dir) and output_dir != "":
os.makedirs(output_dir)
if output_path.endswith(".csv"):
return data.to_csv(output_path, index=False)
@ -415,6 +417,12 @@ def main(args):
data = pd.merge(data, data_new[cols_to_use], on="path", how="inner")
print(f"Intersection number of samples: {len(data)}.")
# train columns
if args.train_column:
all_columns = data.columns
columns_to_drop = all_columns.difference(TRAIN_COLUMNS)
data = data.drop(columns=columns_to_drop)
# get output path
output_path = get_output_path(args, input_name)
@ -508,7 +516,11 @@ def main(args):
assert "num_frames" in data.columns
data = data[data["num_frames"] <= args.fmax]
if args.hwmax is not None:
assert "resolution" in data.columns
if "resolution" not in data.columns:
height = data["height"]
width = data["width"]
data["resolution"] = height * width
breakpoint()
data = data[data["resolution"] <= args.hwmax]
if args.aesmin is not None:
assert "aes" in data.columns
@ -552,6 +564,7 @@ def parse_args():
parser.add_argument(
"--intersection", type=str, default=None, help="keep the paths in csv from the dataset and merge columns"
)
parser.add_argument("--train-column", action="store_true", help="only keep the train column")
# IO-related
parser.add_argument("--info", action="store_true", help="get the basic information of each video and image")