mirror of
https://github.com/hpcaitech/Open-Sora.git
synced 2026-05-20 17:35:58 +02:00
small update
This commit is contained in:
parent
63e86f6dd6
commit
50ad9a1ca5
|
|
@ -65,6 +65,6 @@ log_every = 10
|
|||
ckpt_every = 500
|
||||
load = None
|
||||
|
||||
batch_size = 17 # only for logging
|
||||
batch_size = None
|
||||
lr = 2e-5
|
||||
grad_clip = 1.0
|
||||
|
|
|
|||
|
|
@ -169,11 +169,12 @@ def save(
|
|||
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
|
||||
if lr_scheduler is not None:
|
||||
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
|
||||
sampler_start_idx = step * batch_size if batch_size is not None else None
|
||||
running_states = {
|
||||
"epoch": epoch,
|
||||
"step": step,
|
||||
"global_step": global_step,
|
||||
"sample_start_index": step * batch_size,
|
||||
"sample_start_index": sampler_start_idx,
|
||||
}
|
||||
if coordinator.is_master():
|
||||
save_json(running_states, os.path.join(save_dir, "running_states.json"))
|
||||
|
|
|
|||
|
|
@ -32,6 +32,8 @@ def apply(df, func, **kwargs):
|
|||
return df.progress_apply(func, **kwargs)
|
||||
|
||||
|
||||
TRAIN_COLUMNS = ["path", "text", "num_frames", "fps", "height", "width", "aspect_ratio", "resolution", "text_len"]
|
||||
|
||||
# ======================================================
|
||||
# --info
|
||||
# ======================================================
|
||||
|
|
@ -359,7 +361,7 @@ def read_file(input_path):
|
|||
|
||||
def save_file(data, output_path):
|
||||
output_dir = os.path.dirname(output_path)
|
||||
if not os.path.exists(output_dir):
|
||||
if not os.path.exists(output_dir) and output_dir != "":
|
||||
os.makedirs(output_dir)
|
||||
if output_path.endswith(".csv"):
|
||||
return data.to_csv(output_path, index=False)
|
||||
|
|
@ -415,6 +417,12 @@ def main(args):
|
|||
data = pd.merge(data, data_new[cols_to_use], on="path", how="inner")
|
||||
print(f"Intersection number of samples: {len(data)}.")
|
||||
|
||||
# train columns
|
||||
if args.train_column:
|
||||
all_columns = data.columns
|
||||
columns_to_drop = all_columns.difference(TRAIN_COLUMNS)
|
||||
data = data.drop(columns=columns_to_drop)
|
||||
|
||||
# get output path
|
||||
output_path = get_output_path(args, input_name)
|
||||
|
||||
|
|
@ -508,7 +516,11 @@ def main(args):
|
|||
assert "num_frames" in data.columns
|
||||
data = data[data["num_frames"] <= args.fmax]
|
||||
if args.hwmax is not None:
|
||||
assert "resolution" in data.columns
|
||||
if "resolution" not in data.columns:
|
||||
height = data["height"]
|
||||
width = data["width"]
|
||||
data["resolution"] = height * width
|
||||
breakpoint()
|
||||
data = data[data["resolution"] <= args.hwmax]
|
||||
if args.aesmin is not None:
|
||||
assert "aes" in data.columns
|
||||
|
|
@ -552,6 +564,7 @@ def parse_args():
|
|||
parser.add_argument(
|
||||
"--intersection", type=str, default=None, help="keep the paths in csv from the dataset and merge columns"
|
||||
)
|
||||
parser.add_argument("--train-column", action="store_true", help="only keep the train column")
|
||||
|
||||
# IO-related
|
||||
parser.add_argument("--info", action="store_true", help="get the basic information of each video and image")
|
||||
|
|
|
|||
Loading…
Reference in a new issue