diff --git a/generate_images.py b/generate_images.py index 8a5f411..9fce8ab 100644 --- a/generate_images.py +++ b/generate_images.py @@ -10,33 +10,34 @@ from semdiffusers import SemanticEditPipeline parser = argparse.ArgumentParser(description='generate images') -parser.add_argument('--mode', default='generate', type=str, choices=['generate','edit'], +parser.add_argument('--mode', default='generate', type=str, choices=['generate', 'edit'], help='which edit to conduct') parser.add_argument('--split', default=0, type=int, help='split occupations into chunks to make parallel computation possible') args = parser.parse_args() - + def chunks(xs, n): n = max(1, n) - return list(xs[i:i+n] for i in range(0, len(xs), n)) + return list(xs[i:i + n] for i in range(0, len(xs), n)) cnn_face_detector = dlib.cnn_face_detection_model_v1('dlib_models/mmod_human_face_detector.dat') device = 'cuda' +model_name = "runwayml/stable-diffusion-v1-5" pipe = SemanticEditPipeline.from_pretrained( - "runwayml/stable-diffusion-v1-5", + model_name, ).to(device) gen = torch.Generator(device=device) -num_im = 250 -chunk_size = 15 +num_im = 50 +chunk_size = 30 prompt = 'A photo of the face of a ' with open('occupations.txt') as f: - occupations = [line.rstrip() for line in f] + occupations = [line.split(" ", 1)[0] for line in f][1:] occupations = chunks(occupations, chunk_size)[args.split] @@ -51,11 +52,11 @@ def chunks(xs, n): 'seed': i, 'prompt': prompt + cl, 'num_images_per_prompt': 1 - } + } out = pipe(**params, generator=gen) image = out.images[0] # check if face exists in img with fairface detector - if face_existing(np.array(image), cnn_face_detector)==1: + if face_existing(np.array(image), cnn_face_detector) == 1: image.save(f"{pth}/image{j}.png") with open(f"{pth}/image{j}.json", 'w') as fp: json.dump(params, fp) @@ -63,10 +64,10 @@ def chunks(xs, n): else: print(f'no Face - {i}') i += 1 - - + + elif args.mode == 'edit': - dir_ = [True, False] + dir_ = [True, False] edit1 = ['male person', 'female person'] edit2 = edit1[::-1] @@ -86,19 +87,19 @@ def chunks(xs, n): params = json.load(f) gen.manual_seed(params['seed']) params_edit = {'guidance_scale': params['guidance_scale'], - 'seed': params['seed'], - 'prompt': params['prompt'], - 'num_images_per_prompt': params['num_images_per_prompt'], - 'editing_prompt': edit, - 'reverse_editing_direction': dir_, - 'edit_warmup_steps': 5, - 'edit_guidance_scale': 4, - 'edit_threshold': 0.95, - 'edit_momentum_scale': 0.5, - 'edit_mom_beta': 0.6} + 'seed': params['seed'], + 'prompt': params['prompt'], + 'num_images_per_prompt': params['num_images_per_prompt'], + 'editing_prompt': edit, + 'reverse_editing_direction': dir_, + 'edit_warmup_steps': 5, + 'edit_guidance_scale': 4, + 'edit_threshold': 0.95, + 'edit_momentum_scale': 0.5, + 'edit_mom_beta': 0.6} out = pipe(**params_edit, generator=gen) image = out.images[0] image.save(f"{pth_edit}/image{i}.png") with open(f"{pth_edit}/image{i}.json", 'w') as fp: json.dump(params_edit, fp) - +