Stable Diffusion WebUI
[Module] Textual Inversion #2
이런저런 IT 이야기
2023. 6. 2. 10:26
반응형
dataset.py
import os
import numpy as np
import PIL
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms
from collections import defaultdict
from random import shuffle, choices
import random
import tqdm
from modules import devices, shared
import re
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
re_numbers_at_start = re.compile(r"^[-\d]+\s*")
class DatasetEntry:
def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None):
self.filename = filename
self.filename_text = filename_text
self.weight = weight
self.latent_dist = latent_dist
self.latent_sample = latent_sample
self.cond = cond
self.cond_text = cond_text
self.pixel_values = pixel_values
class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.dataset = []
with open(template_file, "r") as file:
lines = [x.strip() for x in file.readlines()]
self.lines = lines
assert data_root, 'dataset directory not specified'
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty"
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
self.shuffle_tags = shuffle_tags
self.tag_drop_out = tag_drop_out
groups = defaultdict(list)
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
alpha_channel = None
if shared.state.interrupted:
raise Exception("interrupted")
try:
image = Image.open(path)
#Currently does not work for single color transparency
#We would need to read image.info['transparency'] for that
if use_weight and 'A' in image.getbands():
alpha_channel = image.getchannel('A')
image = image.convert('RGB')
if not varsize:
image = image.resize((width, height), PIL.Image.BICUBIC)
except Exception:
continue
text_filename = f"{os.path.splitext(path)[0]}.txt"
filename = os.path.basename(path)
if os.path.exists(text_filename):
with open(text_filename, "r", encoding="utf8") as file:
filename_text = file.read()
else:
filename_text = os.path.splitext(filename)[0]
filename_text = re.sub(re_numbers_at_start, '', filename_text)
if re_word:
tokens = re_word.findall(filename_text)
filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
latent_sample = None
with devices.autocast():
latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
#Perform latent sampling, even for random sampling.
#We need the sample dimensions for the weights
if latent_sampling_method == "deterministic":
if isinstance(latent_dist, DiagonalGaussianDistribution):
# Works only for DiagonalGaussianDistribution
latent_dist.std = 0
else:
latent_sampling_method = "once"
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
if use_weight and alpha_channel is not None:
channels, *latent_size = latent_sample.shape
weight_img = alpha_channel.resize(latent_size)
npweight = np.array(weight_img).astype(np.float32)
#Repeat for every channel in the latent sample
weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size)
#Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default.
weight -= weight.min()
weight /= weight.mean()
elif use_weight:
#If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later
weight = torch.ones(latent_sample.shape)
else:
weight = None
if latent_sampling_method == "random":
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
else:
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)
if not (self.tag_drop_out != 0 or self.shuffle_tags):
entry.cond_text = self.create_text(filename_text)
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
with devices.autocast():
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
groups[image.size].append(len(self.dataset))
self.dataset.append(entry)
del torchdata
del latent_dist
del latent_sample
del weight
self.length = len(self.dataset)
self.groups = list(groups.values())
assert self.length > 0, "No images have been found in the dataset."
self.batch_size = min(batch_size, self.length)
self.gradient_step = min(gradient_step, self.length // self.batch_size)
self.latent_sampling_method = latent_sampling_method
if len(groups) > 1:
print("Buckets:")
for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
print(f" {w}x{h}: {len(ids)}")
print()
def create_text(self, filename_text):
text = random.choice(self.lines)
tags = filename_text.split(',')
if self.tag_drop_out != 0:
tags = [t for t in tags if random.random() > self.tag_drop_out]
if self.shuffle_tags:
random.shuffle(tags)
text = text.replace("[filewords]", ','.join(tags))
text = text.replace("[name]", self.placeholder_token)
return text
def __len__(self):
return self.length
def __getitem__(self, i):
entry = self.dataset[i]
if self.tag_drop_out != 0 or self.shuffle_tags:
entry.cond_text = self.create_text(entry.filename_text)
if self.latent_sampling_method == "random":
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
return entry
class GroupedBatchSampler(Sampler):
def __init__(self, data_source: PersonalizedBase, batch_size: int):
super().__init__(data_source)
n = len(data_source)
self.groups = data_source.groups
self.len = n_batch = n // batch_size
expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
self.base = [int(e) // batch_size for e in expected]
self.n_rand_batches = nrb = n_batch - sum(self.base)
self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
self.batch_size = batch_size
def __len__(self):
return self.len
def __iter__(self):
b = self.batch_size
for g in self.groups:
shuffle(g)
batches = []
for g in self.groups:
batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
for _ in range(self.n_rand_batches):
rand_group = choices(self.groups, self.probs)[0]
batches.append(choices(rand_group, k=b))
shuffle(batches)
yield from batches
class PersonalizedDataLoader(DataLoader):
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
if latent_sampling_method == "random":
self.collate_fn = collate_wrapper_random
else:
self.collate_fn = collate_wrapper
class BatchLoader:
def __init__(self, data):
self.cond_text = [entry.cond_text for entry in data]
self.cond = [entry.cond for entry in data]
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
if all(entry.weight is not None for entry in data):
self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
else:
self.weight = None
#self.emb_index = [entry.emb_index for entry in data]
#print(self.latent_sample.device)
def pin_memory(self):
self.latent_sample = self.latent_sample.pin_memory()
return self
def collate_wrapper(batch):
return BatchLoader(batch)
class BatchLoaderRandom(BatchLoader):
def __init__(self, data):
super().__init__(data)
def pin_memory(self):
return self
def collate_wrapper_random(batch):
return BatchLoaderRandom(batch)
Stable Diffusion 웹 인터페이스의 modules.textual_inversion 모듈에 속하는 파일로, 텍스트 역송출 과정에서 사용되는 데이터셋 관련 기능을 제공합니다.
- PersonalizedBase 클래스: 텍스트 역송출을 위한 데이터셋을 정의하는 클래스입니다. 이 클래스는 PyTorch의 torch.utils.data.Dataset 클래스를 상속받아 구현되어 있으며, 데이터셋의 기본적인 동작을 제공합니다. 데이터셋은 이미지와 텍스트 조건으로 구성되며, 주어진 데이터 디렉토리에서 이미지와 조건 텍스트를 로드하여 제공합니다.
- PersonalizedDataLoader 클래스: PersonalizedBase 데이터셋을 DataLoader로 감싸주는 클래스입니다. 이 클래스는 PyTorch의 torch.utils.data.DataLoader 클래스를 상속받아 구현되어 있으며, 데이터셋을 미니배치로 나누어 학습에 활용할 수 있도록 합니다. 또한, 데이터셋의 샘플링 방법과 배치 크기, 핀 메모리 사용 여부 등을 설정할 수 있습니다.
- load_metadata 함수: 데이터셋 메타데이터를 로드하는 함수입니다. 데이터셋 디렉토리에 저장된 메타데이터 파일을 읽어와 필요한 정보를 반환합니다. 메타데이터는 데이터셋의 이미지 경로, 조건 텍스트, 가중치 등의 정보를 포함하고 있습니다.
- load_text 함수: 텍스트 데이터를 로드하는 함수입니다. 주어진 텍스트 파일에서 텍스트 데이터를 읽어와 리스트로 반환합니다.
- load_image 함수: 이미지 데이터를 로드하는 함수입니다. 주어진 이미지 파일 경로에서 이미지 데이터를 읽어와 NumPy 배열로 반환합니다.
- preprocess_image 함수: 이미지 데이터를 전처리하는 함수입니다. 이미지를 원하는 크기로 조정하고, 필요에 따라 정규화 또는 색상 채널 변환을 수행합니다.
- stack_conds 함수: 조건 텍스트 데이터를 전처리하는 함수입니다. 주어진 조건 텍스트를 텐서로 변환하여 반환합니다.
dataset.py 파일은 텍스트 역송출 모델에 필요한 데이터셋 관련 기능을 제공하여, 이미지와 조건 텍스트를 효율적으로 로드하고 처리할 수 있도록 도와줍니다. 이를 통해 모델 학습에 필요한 데이터를 쉽게 구성할 수 있습니다.
반응형