diff --git a/.gitignore b/.gitignore index 690df8f..97197c1 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,5 @@ models/inswapper_128.onnx models/GFPGANv1.4.pth *.onnx models/DMDNet.pth +faceswap/ +.vscode/ diff --git a/README.md b/README.md index b045426..b203ad1 100644 --- a/README.md +++ b/README.md @@ -146,6 +146,7 @@ options: --keep-audio keep original audio --keep-frames keep temporary frames --many-faces process every face + --map-faces map source target faces --nsfw-filter filter the NSFW image or video --video-encoder {libx264,libx265,libvpx-vp9} adjust output video encoder --video-quality [0-51] adjust output video quality diff --git a/modules/cluster_analysis.py b/modules/cluster_analysis.py new file mode 100644 index 0000000..0e7db03 --- /dev/null +++ b/modules/cluster_analysis.py @@ -0,0 +1,32 @@ +import numpy as np +from sklearn.cluster import KMeans +from sklearn.metrics import silhouette_score +from typing import Any + + +def find_cluster_centroids(embeddings, max_k=10) -> Any: + inertia = [] + cluster_centroids = [] + K = range(1, max_k+1) + + for k in K: + kmeans = KMeans(n_clusters=k, random_state=0) + kmeans.fit(embeddings) + inertia.append(kmeans.inertia_) + cluster_centroids.append({"k": k, "centroids": kmeans.cluster_centers_}) + + diffs = [inertia[i] - inertia[i+1] for i in range(len(inertia)-1)] + optimal_centroids = cluster_centroids[diffs.index(max(diffs)) + 1]['centroids'] + + return optimal_centroids + +def find_closest_centroid(centroids: list, normed_face_embedding) -> list: + try: + centroids = np.array(centroids) + normed_face_embedding = np.array(normed_face_embedding) + similarities = np.dot(centroids, normed_face_embedding) + closest_centroid_index = np.argmax(similarities) + + return closest_centroid_index, centroids[closest_centroid_index] + except ValueError: + return None \ No newline at end of file diff --git a/modules/core.py b/modules/core.py index 9de11ca..55707e2 100644 --- a/modules/core.py +++ b/modules/core.py @@ -40,6 +40,7 @@ def parse_args() -> None: program.add_argument('--keep-frames', help='keep temporary frames', dest='keep_frames', action='store_true', default=False) program.add_argument('--many-faces', help='process every face', dest='many_faces', action='store_true', default=False) program.add_argument('--nsfw-filter', help='filter the NSFW image or video', dest='nsfw_filter', action='store_true', default=False) + program.add_argument('--map-faces', help='map source target faces', dest='map_faces', action='store_true', default=False) program.add_argument('--video-encoder', help='adjust output video encoder', dest='video_encoder', default='libx264', choices=['libx264', 'libx265', 'libvpx-vp9']) program.add_argument('--video-quality', help='adjust output video quality', dest='video_quality', type=int, default=18, choices=range(52), metavar='[0-51]') program.add_argument('--live-mirror', help='The live camera display as you see it in the front-facing camera frame', dest='live_mirror', action='store_true', default=False) @@ -67,6 +68,7 @@ def parse_args() -> None: modules.globals.keep_frames = args.keep_frames modules.globals.many_faces = args.many_faces modules.globals.nsfw_filter = args.nsfw_filter + modules.globals.map_faces = args.map_faces modules.globals.video_encoder = args.video_encoder modules.globals.video_quality = args.video_quality modules.globals.live_mirror = args.live_mirror @@ -194,10 +196,13 @@ def start() -> None: # process image to videos if modules.globals.nsfw_filter and ui.check_and_ignore_nsfw(modules.globals.target_path, destroy): return - update_status('Creating temp resources...') - create_temp(modules.globals.target_path) - update_status('Extracting frames...') - extract_frames(modules.globals.target_path) + + if not modules.globals.map_faces: + update_status('Creating temp resources...') + create_temp(modules.globals.target_path) + update_status('Extracting frames...') + extract_frames(modules.globals.target_path) + temp_frame_paths = get_temp_frame_paths(modules.globals.target_path) for frame_processor in get_frame_processors_modules(modules.globals.frame_processors): update_status('Progressing...', frame_processor.NAME) diff --git a/modules/face_analyser.py b/modules/face_analyser.py index f2d46bf..2122784 100644 --- a/modules/face_analyser.py +++ b/modules/face_analyser.py @@ -1,8 +1,16 @@ +import os +import shutil from typing import Any import insightface +import cv2 +import numpy as np import modules.globals +from tqdm import tqdm from modules.typing import Frame +from modules.cluster_analysis import find_cluster_centroids, find_closest_centroid +from modules.utilities import get_temp_directory_path, create_temp, extract_frames, clean_temp, get_temp_frame_paths +from pathlib import Path FACE_ANALYSER = None @@ -29,3 +37,153 @@ def get_many_faces(frame: Frame) -> Any: return get_face_analyser().get(frame) except IndexError: return None + +def has_valid_map() -> bool: + for map in modules.globals.souce_target_map: + if "source" in map and "target" in map: + return True + return False + +def default_source_face() -> Any: + for map in modules.globals.souce_target_map: + if "source" in map: + return map['source']['face'] + return None + +def simplify_maps() -> Any: + centroids = [] + faces = [] + for map in modules.globals.souce_target_map: + if "source" in map and "target" in map: + centroids.append(map['target']['face'].normed_embedding) + faces.append(map['source']['face']) + + modules.globals.simple_map = {'source_faces': faces, 'target_embeddings': centroids} + return None + +def add_blank_map() -> Any: + try: + max_id = -1 + if len(modules.globals.souce_target_map) > 0: + max_id = max(modules.globals.souce_target_map, key=lambda x: x['id'])['id'] + + modules.globals.souce_target_map.append({ + 'id' : max_id + 1 + }) + except ValueError: + return None + +def get_unique_faces_from_target_image() -> Any: + try: + modules.globals.souce_target_map = [] + target_frame = cv2.imread(modules.globals.target_path) + many_faces = get_many_faces(target_frame) + i = 0 + + for face in many_faces: + x_min, y_min, x_max, y_max = face['bbox'] + modules.globals.souce_target_map.append({ + 'id' : i, + 'target' : { + 'cv2' : target_frame[int(y_min):int(y_max), int(x_min):int(x_max)], + 'face' : face + } + }) + i = i + 1 + except ValueError: + return None + + +def get_unique_faces_from_target_video() -> Any: + try: + modules.globals.souce_target_map = [] + frame_face_embeddings = [] + face_embeddings = [] + + print('Creating temp resources...') + clean_temp(modules.globals.target_path) + create_temp(modules.globals.target_path) + print('Extracting frames...') + extract_frames(modules.globals.target_path) + + temp_frame_paths = get_temp_frame_paths(modules.globals.target_path) + + i = 0 + for temp_frame_path in tqdm(temp_frame_paths, desc="Extracting face embeddings from frames"): + temp_frame = cv2.imread(temp_frame_path) + many_faces = get_many_faces(temp_frame) + + for face in many_faces: + face_embeddings.append(face.normed_embedding) + + frame_face_embeddings.append({'frame': i, 'faces': many_faces, 'location': temp_frame_path}) + i += 1 + + centroids = find_cluster_centroids(face_embeddings) + + for frame in frame_face_embeddings: + for face in frame['faces']: + closest_centroid_index, _ = find_closest_centroid(centroids, face.normed_embedding) + face['target_centroid'] = closest_centroid_index + + for i in range(len(centroids)): + modules.globals.souce_target_map.append({ + 'id' : i + }) + + temp = [] + for frame in tqdm(frame_face_embeddings, desc=f"Mapping frame embeddings to centroids-{i}"): + temp.append({'frame': frame['frame'], 'faces': [face for face in frame['faces'] if face['target_centroid'] == i], 'location': frame['location']}) + + modules.globals.souce_target_map[i]['target_faces_in_frame'] = temp + + # dump_faces(centroids, frame_face_embeddings) + default_target_face() + except ValueError: + return None + + +def default_target_face(): + for map in modules.globals.souce_target_map: + best_face = None + best_frame = None + for frame in map['target_faces_in_frame']: + if len(frame['faces']) > 0: + best_face = frame['faces'][0] + best_frame = frame + break + + for frame in map['target_faces_in_frame']: + for face in frame['faces']: + if face['det_score'] > best_face['det_score']: + best_face = face + best_frame = frame + + x_min, y_min, x_max, y_max = best_face['bbox'] + + target_frame = cv2.imread(best_frame['location']) + map['target'] = { + 'cv2' : target_frame[int(y_min):int(y_max), int(x_min):int(x_max)], + 'face' : best_face + } + + +def dump_faces(centroids: Any, frame_face_embeddings: list): + temp_directory_path = get_temp_directory_path(modules.globals.target_path) + + for i in range(len(centroids)): + if os.path.exists(temp_directory_path + f"/{i}") and os.path.isdir(temp_directory_path + f"/{i}"): + shutil.rmtree(temp_directory_path + f"/{i}") + Path(temp_directory_path + f"/{i}").mkdir(parents=True, exist_ok=True) + + for frame in tqdm(frame_face_embeddings, desc=f"Copying faces to temp/./{i}"): + temp_frame = cv2.imread(frame['location']) + + j = 0 + for face in frame['faces']: + if face['target_centroid'] == i: + x_min, y_min, x_max, y_max = face['bbox'] + + if temp_frame[int(y_min):int(y_max), int(x_min):int(x_max)].size > 0: + cv2.imwrite(temp_directory_path + f"/{i}/{frame['frame']}_{j}.png", temp_frame[int(y_min):int(y_max), int(x_min):int(x_max)]) + j += 1 \ No newline at end of file diff --git a/modules/globals.py b/modules/globals.py index be09102..16ed2b5 100644 --- a/modules/globals.py +++ b/modules/globals.py @@ -1,5 +1,5 @@ import os -from typing import List, Dict +from typing import List, Dict, Any ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) WORKFLOW_DIR = os.path.join(ROOT_DIR, 'workflow') @@ -9,6 +9,9 @@ file_types = [ ('Video', ('*.mp4','*.mkv')) ] +souce_target_map = [] +simple_map = {} + source_path = None target_path = None output_path = None @@ -17,6 +20,7 @@ keep_fps = None keep_audio = None keep_frames = None many_faces = None +map_faces = None color_correction = None # New global variable for color correction toggle nsfw_filter = None video_encoder = None diff --git a/modules/processors/frame/face_swapper.py b/modules/processors/frame/face_swapper.py index c65693e..caf493e 100644 --- a/modules/processors/frame/face_swapper.py +++ b/modules/processors/frame/face_swapper.py @@ -6,9 +6,10 @@ import threading import modules.globals import modules.processors.frame.core from modules.core import update_status -from modules.face_analyser import get_one_face, get_many_faces +from modules.face_analyser import get_one_face, get_many_faces, default_source_face from modules.typing import Face, Frame from modules.utilities import conditional_download, resolve_relative_path, is_image, is_video +from modules.cluster_analysis import find_closest_centroid FACE_SWAPPER = None THREAD_LOCK = threading.Lock() @@ -22,10 +23,10 @@ def pre_check() -> bool: def pre_start() -> bool: - if not is_image(modules.globals.source_path): + if not modules.globals.map_faces and not is_image(modules.globals.source_path): update_status('Select an image for source path.', NAME) return False - elif not get_one_face(cv2.imread(modules.globals.source_path)): + elif not modules.globals.map_faces and not get_one_face(cv2.imread(modules.globals.source_path)): update_status('No face in source path detected.', NAME) return False if not is_image(modules.globals.target_path) and not is_video(modules.globals.target_path): @@ -65,26 +66,98 @@ def process_frame(source_face: Face, temp_frame: Frame) -> Frame: return temp_frame +def process_frame_v2(temp_frame: Frame, temp_frame_path: str = "") -> Frame: + if is_image(modules.globals.target_path): + if modules.globals.many_faces: + source_face = default_source_face() + for map in modules.globals.souce_target_map: + target_face = map['target']['face'] + temp_frame = swap_face(source_face, target_face, temp_frame) + + elif not modules.globals.many_faces: + for map in modules.globals.souce_target_map: + if "source" in map: + source_face = map['source']['face'] + target_face = map['target']['face'] + temp_frame = swap_face(source_face, target_face, temp_frame) + + elif is_video(modules.globals.target_path): + if modules.globals.many_faces: + source_face = default_source_face() + for map in modules.globals.souce_target_map: + target_frame = [f for f in map['target_faces_in_frame'] if f['location'] == temp_frame_path] + + for frame in target_frame: + for target_face in frame['faces']: + temp_frame = swap_face(source_face, target_face, temp_frame) + + elif not modules.globals.many_faces: + for map in modules.globals.souce_target_map: + if "source" in map: + target_frame = [f for f in map['target_faces_in_frame'] if f['location'] == temp_frame_path] + source_face = map['source']['face'] + + for frame in target_frame: + for target_face in frame['faces']: + temp_frame = swap_face(source_face, target_face, temp_frame) + else: + many_faces = get_many_faces(temp_frame) + if modules.globals.many_faces: + source_face = default_source_face() + if many_faces: + for target_face in many_faces: + temp_frame = swap_face(source_face, target_face, temp_frame) + + elif not modules.globals.many_faces: + if many_faces: + for target_face in many_faces: + closest_centroid_index, _ = find_closest_centroid(modules.globals.simple_map['target_embeddings'], target_face.normed_embedding) + + temp_frame = swap_face(modules.globals.simple_map['source_faces'][closest_centroid_index], target_face, temp_frame) + return temp_frame + + def process_frames(source_path: str, temp_frame_paths: List[str], progress: Any = None) -> None: - source_face = get_one_face(cv2.imread(source_path)) - for temp_frame_path in temp_frame_paths: - temp_frame = cv2.imread(temp_frame_path) - try: - result = process_frame(source_face, temp_frame) - cv2.imwrite(temp_frame_path, result) - except Exception as exception: - print(exception) - pass - if progress: - progress.update(1) + if not modules.globals.map_faces: + source_face = get_one_face(cv2.imread(source_path)) + for temp_frame_path in temp_frame_paths: + temp_frame = cv2.imread(temp_frame_path) + try: + result = process_frame(source_face, temp_frame) + cv2.imwrite(temp_frame_path, result) + except Exception as exception: + print(exception) + pass + if progress: + progress.update(1) + else: + for temp_frame_path in temp_frame_paths: + temp_frame = cv2.imread(temp_frame_path) + try: + result = process_frame_v2(temp_frame, temp_frame_path) + cv2.imwrite(temp_frame_path, result) + except Exception as exception: + print(exception) + pass + if progress: + progress.update(1) def process_image(source_path: str, target_path: str, output_path: str) -> None: - source_face = get_one_face(cv2.imread(source_path)) - target_frame = cv2.imread(target_path) - result = process_frame(source_face, target_frame) - cv2.imwrite(output_path, result) + if not modules.globals.map_faces: + source_face = get_one_face(cv2.imread(source_path)) + target_frame = cv2.imread(target_path) + result = process_frame(source_face, target_frame) + cv2.imwrite(output_path, result) + else: + if modules.globals.many_faces: + update_status('Many faces enabled. Using first source image. Progressing...', NAME) + target_frame = cv2.imread(output_path) + result = process_frame_v2(target_frame) + cv2.imwrite(output_path, result) def process_video(source_path: str, temp_frame_paths: List[str]) -> None: + if modules.globals.map_faces and modules.globals.many_faces: + update_status('Many faces enabled. Using first source image. Progressing...', NAME) modules.processors.frame.core.process_video(source_path, temp_frame_paths, process_frames) diff --git a/modules/ui.py b/modules/ui.py index 1d91253..a8c6522 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -7,12 +7,14 @@ from PIL import Image, ImageOps import modules.globals import modules.metadata -from modules.face_analyser import get_one_face +from modules.face_analyser import get_one_face, get_unique_faces_from_target_image, get_unique_faces_from_target_video, add_blank_map, has_valid_map, simplify_maps from modules.capturer import get_video_frame, get_video_frame_total from modules.processors.frame.core import get_frame_processors_modules from modules.utilities import is_image, is_video, resolve_relative_path, has_image_extension ROOT = None +POPUP = None +POPUP_LIVE = None ROOT_HEIGHT = 700 ROOT_WIDTH = 600 @@ -22,6 +24,22 @@ PREVIEW_MAX_WIDTH = 1200 PREVIEW_DEFAULT_WIDTH = 960 PREVIEW_DEFAULT_HEIGHT = 540 +POPUP_WIDTH = 750 +POPUP_HEIGHT = 810 +POPUP_SCROLL_WIDTH = 740, +POPUP_SCROLL_HEIGHT = 700 + +POPUP_LIVE_WIDTH = 900 +POPUP_LIVE_HEIGHT = 820 +POPUP_LIVE_SCROLL_WIDTH = 890, +POPUP_LIVE_SCROLL_HEIGHT = 700 + +MAPPER_PREVIEW_MAX_HEIGHT = 100 +MAPPER_PREVIEW_MAX_WIDTH = 100 + +DEFAULT_BUTTON_WIDTH = 200 +DEFAULT_BUTTON_HEIGHT = 40 + RECENT_DIRECTORY_SOURCE = None RECENT_DIRECTORY_TARGET = None RECENT_DIRECTORY_OUTPUT = None @@ -31,6 +49,11 @@ preview_slider = None source_label = None target_label = None status_label = None +popup_status_label = None +popup_status_label_live = None +source_label_dict = {} +source_label_dict_live = {} +target_label_dict_live = {} img_ft, vid_ft = modules.globals.file_types @@ -102,7 +125,11 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C # nsfw_switch = ctk.CTkSwitch(root, text='NSFW filter', variable=nsfw_value, cursor='hand2', command=lambda: setattr(modules.globals, 'nsfw_filter', nsfw_value.get())) # nsfw_switch.place(relx=0.6, rely=0.7) - start_button = ctk.CTkButton(root, text='Start', cursor='hand2', command=lambda: select_output_path(start)) + map_faces = ctk.BooleanVar(value=modules.globals.map_faces) + map_faces_switch = ctk.CTkSwitch(root, text='Map faces', variable=map_faces, cursor='hand2', command=lambda: setattr(modules.globals, 'map_faces', map_faces.get())) + map_faces_switch.place(relx=0.1, rely=0.75) + + start_button = ctk.CTkButton(root, text='Start', cursor='hand2', command=lambda: analyze_target(start, root)) start_button.place(relx=0.15, rely=0.80, relwidth=0.2, relheight=0.05) stop_button = ctk.CTkButton(root, text='Destroy', cursor='hand2', command=lambda: destroy()) @@ -111,7 +138,7 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C preview_button = ctk.CTkButton(root, text='Preview', cursor='hand2', command=lambda: toggle_preview()) preview_button.place(relx=0.65, rely=0.80, relwidth=0.2, relheight=0.05) - live_button = ctk.CTkButton(root, text='Live', cursor='hand2', command=lambda: webcam_preview()) + live_button = ctk.CTkButton(root, text='Live', cursor='hand2', command=lambda: webcam_preview(root)) live_button.place(relx=0.40, rely=0.86, relwidth=0.2, relheight=0.05) status_label = ctk.CTkLabel(root, text=None, justify='center') @@ -124,6 +151,109 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C return root +def analyze_target(start: Callable[[], None], root: ctk.CTk): + if POPUP != None and POPUP.winfo_exists(): + update_status("Please complete pop-up or close it.") + return + + if modules.globals.map_faces: + modules.globals.souce_target_map = [] + + if is_image(modules.globals.target_path): + update_status('Getting unique faces') + get_unique_faces_from_target_image() + elif is_video(modules.globals.target_path): + update_status('Getting unique faces') + get_unique_faces_from_target_video() + + if len(modules.globals.souce_target_map) > 0: + create_source_target_popup(start, root, modules.globals.souce_target_map) + else: + update_status("No faces found in target") + else: + select_output_path(start) + +def create_source_target_popup(start: Callable[[], None], root: ctk.CTk, map: list) -> None: + global POPUP, popup_status_label + + POPUP = ctk.CTkToplevel(root) + POPUP.title("Source x Target Mapper") + POPUP.geometry(f"{POPUP_WIDTH}x{POPUP_HEIGHT}") + POPUP.focus() + + def on_submit_click(start): + if has_valid_map(): + POPUP.destroy() + select_output_path(start) + else: + update_pop_status("Atleast 1 source with target is required!") + + scrollable_frame = ctk.CTkScrollableFrame(POPUP, width=POPUP_SCROLL_WIDTH, height=POPUP_SCROLL_HEIGHT) + scrollable_frame.grid(row=0, column=0, padx=0, pady=0, sticky='nsew') + + def on_button_click(map, button_num): + map = update_popup_source(scrollable_frame, map, button_num) + + for item in map: + id = item['id'] + + button = ctk.CTkButton(scrollable_frame, text="Select source image", command=lambda id=id: on_button_click(map, id), width=DEFAULT_BUTTON_WIDTH, height=DEFAULT_BUTTON_HEIGHT) + button.grid(row=id, column=0, padx=50, pady=10) + + x_label = ctk.CTkLabel(scrollable_frame, text=f"X", width=MAPPER_PREVIEW_MAX_WIDTH, height=MAPPER_PREVIEW_MAX_HEIGHT) + x_label.grid(row=id, column=2, padx=10, pady=10) + + image = Image.fromarray(cv2.cvtColor(item['target']['cv2'], cv2.COLOR_BGR2RGB)) + image = image.resize((MAPPER_PREVIEW_MAX_WIDTH, MAPPER_PREVIEW_MAX_HEIGHT), Image.LANCZOS) + tk_image = ctk.CTkImage(image, size=image.size) + + target_image = ctk.CTkLabel(scrollable_frame, text=f"T-{id}", width=MAPPER_PREVIEW_MAX_WIDTH, height=MAPPER_PREVIEW_MAX_HEIGHT) + target_image.grid(row=id, column=3, padx=10, pady=10) + target_image.configure(image=tk_image) + + popup_status_label = ctk.CTkLabel(POPUP, text=None, justify='center') + popup_status_label.grid(row=1, column=0, pady=15) + + close_button = ctk.CTkButton(POPUP, text="Submit", command=lambda: on_submit_click(start)) + close_button.grid(row=2, column=0, pady=10) + + +def update_popup_source(scrollable_frame: ctk.CTkScrollableFrame, map: list, button_num: int) -> list: + global source_label_dict + + source_path = ctk.filedialog.askopenfilename(title='select an source image', initialdir=RECENT_DIRECTORY_SOURCE, filetypes=[img_ft]) + + if "source" in map[button_num]: + map[button_num].pop("source") + source_label_dict[button_num].destroy() + del source_label_dict[button_num] + + if source_path == "": + return map + else: + cv2_img = cv2.imread(source_path) + face = get_one_face(cv2_img) + + if face: + x_min, y_min, x_max, y_max = face['bbox'] + + map[button_num]['source'] = { + 'cv2' : cv2_img[int(y_min):int(y_max), int(x_min):int(x_max)], + 'face' : face + } + + image = Image.fromarray(cv2.cvtColor(map[button_num]['source']['cv2'], cv2.COLOR_BGR2RGB)) + image = image.resize((MAPPER_PREVIEW_MAX_WIDTH, MAPPER_PREVIEW_MAX_HEIGHT), Image.LANCZOS) + tk_image = ctk.CTkImage(image, size=image.size) + + source_image = ctk.CTkLabel(scrollable_frame, text=f"S-{button_num}", width=MAPPER_PREVIEW_MAX_WIDTH, height=MAPPER_PREVIEW_MAX_HEIGHT) + source_image.grid(row=button_num, column=1, padx=10, pady=10) + source_image.configure(image=tk_image) + source_label_dict[button_num] = source_image + else: + update_pop_status("Face could not be detected in last upload!") + return map + def create_preview(parent: ctk.CTkToplevel) -> ctk.CTkToplevel: global preview_label, preview_slider @@ -147,6 +277,11 @@ def update_status(text: str) -> None: status_label.configure(text=text) ROOT.update() +def update_pop_status(text: str) -> None: + popup_status_label.configure(text=text) + +def update_pop_live_status(text: str) -> None: + popup_status_label_live.configure(text=text) def update_tumbler(var: str, value: bool) -> None: modules.globals.fp_ui[var] = value @@ -315,11 +450,17 @@ def update_preview(frame_number: int = 0) -> None: update_status('Processing succeed!') PREVIEW.deiconify() -def webcam_preview(): - if modules.globals.source_path is None: - # No image selected - return +def webcam_preview(root: ctk.CTk): + if not modules.globals.map_faces: + if modules.globals.source_path is None: + # No image selected + return + create_webcam_preview() + else: + modules.globals.souce_target_map = [] + create_source_target_popup_for_webcam(root, modules.globals.souce_target_map) +def create_webcam_preview(): global preview_label, PREVIEW camera = cv2.VideoCapture(0) # Use index for the webcam (adjust the index accordingly if necessary) @@ -340,10 +481,6 @@ def webcam_preview(): if not ret: break - # Select and save face image only once - if source_image is None and modules.globals.source_path: - source_image = get_one_face(cv2.imread(modules.globals.source_path)) - temp_frame = frame.copy() #Create a copy of the frame if modules.globals.live_mirror: @@ -352,8 +489,18 @@ def webcam_preview(): if modules.globals.live_resizable: temp_frame = fit_image_to_size(temp_frame, PREVIEW.winfo_width(), PREVIEW.winfo_height()) - for frame_processor in frame_processors: - temp_frame = frame_processor.process_frame(source_image, temp_frame) + if not modules.globals.map_faces: + # Select and save face image only once + if source_image is None and modules.globals.source_path: + source_image = get_one_face(cv2.imread(modules.globals.source_path)) + + for frame_processor in frame_processors: + temp_frame = frame_processor.process_frame(source_image, temp_frame) + else: + modules.globals.target_path = None + + for frame_processor in frame_processors: + temp_frame = frame_processor.process_frame_v2(temp_frame) image = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) # Convert the image to RGB format to display it with Tkinter image = Image.fromarray(image) @@ -367,3 +514,151 @@ def webcam_preview(): camera.release() PREVIEW.withdraw() # Close preview window when loop is finished + + +def create_source_target_popup_for_webcam(root: ctk.CTk, map: list) -> None: + global POPUP_LIVE, popup_status_label_live + + POPUP_LIVE = ctk.CTkToplevel(root) + POPUP_LIVE.title("Source x Target Mapper") + POPUP_LIVE.geometry(f"{POPUP_LIVE_WIDTH}x{POPUP_LIVE_HEIGHT}") + POPUP_LIVE.focus() + + def on_submit_click(): + if has_valid_map(): + POPUP_LIVE.destroy() + simplify_maps() + create_webcam_preview() + else: + update_pop_live_status("Atleast 1 source with target is required!") + + def on_add_click(): + add_blank_map() + refresh_data(map) + update_pop_live_status("Please provide mapping!") + + popup_status_label_live = ctk.CTkLabel(POPUP_LIVE, text=None, justify='center') + popup_status_label_live.grid(row=1, column=0, pady=15) + + add_button = ctk.CTkButton(POPUP_LIVE, text="Add", command=lambda: on_add_click()) + add_button.place(relx=0.2, rely=0.92, relwidth=0.2, relheight=0.05) + + close_button = ctk.CTkButton(POPUP_LIVE, text="Submit", command=lambda: on_submit_click()) + close_button.place(relx=0.6, rely=0.92, relwidth=0.2, relheight=0.05) + + +def refresh_data(map: list): + global POPUP_LIVE + + scrollable_frame = ctk.CTkScrollableFrame(POPUP_LIVE, width=POPUP_LIVE_SCROLL_WIDTH, height=POPUP_LIVE_SCROLL_HEIGHT) + scrollable_frame.grid(row=0, column=0, padx=0, pady=0, sticky='nsew') + + def on_sbutton_click(map, button_num): + map = update_webcam_source(scrollable_frame, map, button_num) + + def on_tbutton_click(map, button_num): + map = update_webcam_target(scrollable_frame, map, button_num) + + for item in map: + id = item['id'] + + button = ctk.CTkButton(scrollable_frame, text="Select source image", command=lambda id=id: on_sbutton_click(map, id), width=DEFAULT_BUTTON_WIDTH, height=DEFAULT_BUTTON_HEIGHT) + button.grid(row=id, column=0, padx=30, pady=10) + + x_label = ctk.CTkLabel(scrollable_frame, text=f"X", width=MAPPER_PREVIEW_MAX_WIDTH, height=MAPPER_PREVIEW_MAX_HEIGHT) + x_label.grid(row=id, column=2, padx=10, pady=10) + + button = ctk.CTkButton(scrollable_frame, text="Select target image", command=lambda id=id: on_tbutton_click(map, id), width=DEFAULT_BUTTON_WIDTH, height=DEFAULT_BUTTON_HEIGHT) + button.grid(row=id, column=3, padx=20, pady=10) + + if "source" in item: + image = Image.fromarray(cv2.cvtColor(item['source']['cv2'], cv2.COLOR_BGR2RGB)) + image = image.resize((MAPPER_PREVIEW_MAX_WIDTH, MAPPER_PREVIEW_MAX_HEIGHT), Image.LANCZOS) + tk_image = ctk.CTkImage(image, size=image.size) + + source_image = ctk.CTkLabel(scrollable_frame, text=f"S-{id}", width=MAPPER_PREVIEW_MAX_WIDTH, height=MAPPER_PREVIEW_MAX_HEIGHT) + source_image.grid(row=id, column=1, padx=10, pady=10) + source_image.configure(image=tk_image) + + if "target" in item: + image = Image.fromarray(cv2.cvtColor(item['target']['cv2'], cv2.COLOR_BGR2RGB)) + image = image.resize((MAPPER_PREVIEW_MAX_WIDTH, MAPPER_PREVIEW_MAX_HEIGHT), Image.LANCZOS) + tk_image = ctk.CTkImage(image, size=image.size) + + target_image = ctk.CTkLabel(scrollable_frame, text=f"T-{id}", width=MAPPER_PREVIEW_MAX_WIDTH, height=MAPPER_PREVIEW_MAX_HEIGHT) + target_image.grid(row=id, column=4, padx=20, pady=10) + target_image.configure(image=tk_image) + + +def update_webcam_source(scrollable_frame: ctk.CTkScrollableFrame, map: list, button_num: int) -> list: + global source_label_dict_live + + source_path = ctk.filedialog.askopenfilename(title='select an source image', initialdir=RECENT_DIRECTORY_SOURCE, filetypes=[img_ft]) + + if "source" in map[button_num]: + map[button_num].pop("source") + source_label_dict_live[button_num].destroy() + del source_label_dict_live[button_num] + + if source_path == "": + return map + else: + cv2_img = cv2.imread(source_path) + face = get_one_face(cv2_img) + + if face: + x_min, y_min, x_max, y_max = face['bbox'] + + map[button_num]['source'] = { + 'cv2' : cv2_img[int(y_min):int(y_max), int(x_min):int(x_max)], + 'face' : face + } + + image = Image.fromarray(cv2.cvtColor(map[button_num]['source']['cv2'], cv2.COLOR_BGR2RGB)) + image = image.resize((MAPPER_PREVIEW_MAX_WIDTH, MAPPER_PREVIEW_MAX_HEIGHT), Image.LANCZOS) + tk_image = ctk.CTkImage(image, size=image.size) + + source_image = ctk.CTkLabel(scrollable_frame, text=f"S-{button_num}", width=MAPPER_PREVIEW_MAX_WIDTH, height=MAPPER_PREVIEW_MAX_HEIGHT) + source_image.grid(row=button_num, column=1, padx=10, pady=10) + source_image.configure(image=tk_image) + source_label_dict_live[button_num] = source_image + else: + update_pop_live_status("Face could not be detected in last upload!") + return map + +def update_webcam_target(scrollable_frame: ctk.CTkScrollableFrame, map: list, button_num: int) -> list: + global target_label_dict_live + + target_path = ctk.filedialog.askopenfilename(title='select an target image', initialdir=RECENT_DIRECTORY_SOURCE, filetypes=[img_ft]) + + if "target" in map[button_num]: + map[button_num].pop("target") + target_label_dict_live[button_num].destroy() + del target_label_dict_live[button_num] + + if target_path == "": + return map + else: + cv2_img = cv2.imread(target_path) + face = get_one_face(cv2_img) + + if face: + x_min, y_min, x_max, y_max = face['bbox'] + + map[button_num]['target'] = { + 'cv2' : cv2_img[int(y_min):int(y_max), int(x_min):int(x_max)], + 'face' : face + } + + image = Image.fromarray(cv2.cvtColor(map[button_num]['target']['cv2'], cv2.COLOR_BGR2RGB)) + image = image.resize((MAPPER_PREVIEW_MAX_WIDTH, MAPPER_PREVIEW_MAX_HEIGHT), Image.LANCZOS) + tk_image = ctk.CTkImage(image, size=image.size) + + target_image = ctk.CTkLabel(scrollable_frame, text=f"T-{button_num}", width=MAPPER_PREVIEW_MAX_WIDTH, height=MAPPER_PREVIEW_MAX_HEIGHT) + target_image.grid(row=button_num, column=4, padx=20, pady=10) + target_image.configure(image=tk_image) + target_label_dict_live[button_num] = target_image + else: + update_pop_live_status("Face could not be detected in last upload!") + return map +