이런저런 IT 이야기
article thumbnail
반응형

dataset.py

import os
import numpy as np
import PIL
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms
from collections import defaultdict
from random import shuffle, choices

import random
import tqdm
from modules import devices, shared
import re

from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

re_numbers_at_start = re.compile(r"^[-\d]+\s*")


class DatasetEntry:
    def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None):
        self.filename = filename
        self.filename_text = filename_text
        self.weight = weight
        self.latent_dist = latent_dist
        self.latent_sample = latent_sample
        self.cond = cond
        self.cond_text = cond_text
        self.pixel_values = pixel_values


class PersonalizedBase(Dataset):
    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
        re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None

        self.placeholder_token = placeholder_token

        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

        self.dataset = []

        with open(template_file, "r") as file:
            lines = [x.strip() for x in file.readlines()]

        self.lines = lines

        assert data_root, 'dataset directory not specified'
        assert os.path.isdir(data_root), "Dataset directory doesn't exist"
        assert os.listdir(data_root), "Dataset directory is empty"

        self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]

        self.shuffle_tags = shuffle_tags
        self.tag_drop_out = tag_drop_out
        groups = defaultdict(list)

        print("Preparing dataset...")
        for path in tqdm.tqdm(self.image_paths):
            alpha_channel = None
            if shared.state.interrupted:
                raise Exception("interrupted")
            try:
                image = Image.open(path)
                #Currently does not work for single color transparency
                #We would need to read image.info['transparency'] for that
                if use_weight and 'A' in image.getbands():
                    alpha_channel = image.getchannel('A')
                image = image.convert('RGB')
                if not varsize:
                    image = image.resize((width, height), PIL.Image.BICUBIC)
            except Exception:
                continue

            text_filename = f"{os.path.splitext(path)[0]}.txt"
            filename = os.path.basename(path)

            if os.path.exists(text_filename):
                with open(text_filename, "r", encoding="utf8") as file:
                    filename_text = file.read()
            else:
                filename_text = os.path.splitext(filename)[0]
                filename_text = re.sub(re_numbers_at_start, '', filename_text)
                if re_word:
                    tokens = re_word.findall(filename_text)
                    filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)

            npimage = np.array(image).astype(np.uint8)
            npimage = (npimage / 127.5 - 1.0).astype(np.float32)

            torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
            latent_sample = None

            with devices.autocast():
                latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))

            #Perform latent sampling, even for random sampling.
            #We need the sample dimensions for the weights
            if latent_sampling_method == "deterministic":
                if isinstance(latent_dist, DiagonalGaussianDistribution):
                    # Works only for DiagonalGaussianDistribution
                    latent_dist.std = 0
                else:
                    latent_sampling_method = "once"
            latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)

            if use_weight and alpha_channel is not None:
                channels, *latent_size = latent_sample.shape
                weight_img = alpha_channel.resize(latent_size)
                npweight = np.array(weight_img).astype(np.float32)
                #Repeat for every channel in the latent sample
                weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size)
                #Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default.
                weight -= weight.min()
                weight /= weight.mean()
            elif use_weight:
                #If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later
                weight = torch.ones(latent_sample.shape)
            else:
                weight = None

            if latent_sampling_method == "random":
                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
            else:
                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)

            if not (self.tag_drop_out != 0 or self.shuffle_tags):
                entry.cond_text = self.create_text(filename_text)

            if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
                with devices.autocast():
                    entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
            groups[image.size].append(len(self.dataset))
            self.dataset.append(entry)
            del torchdata
            del latent_dist
            del latent_sample
            del weight

        self.length = len(self.dataset)
        self.groups = list(groups.values())
        assert self.length > 0, "No images have been found in the dataset."
        self.batch_size = min(batch_size, self.length)
        self.gradient_step = min(gradient_step, self.length // self.batch_size)
        self.latent_sampling_method = latent_sampling_method

        if len(groups) > 1:
            print("Buckets:")
            for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
                print(f"  {w}x{h}: {len(ids)}")
            print()

    def create_text(self, filename_text):
        text = random.choice(self.lines)
        tags = filename_text.split(',')
        if self.tag_drop_out != 0:
            tags = [t for t in tags if random.random() > self.tag_drop_out]
        if self.shuffle_tags:
            random.shuffle(tags)
        text = text.replace("[filewords]", ','.join(tags))
        text = text.replace("[name]", self.placeholder_token)
        return text

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        entry = self.dataset[i]
        if self.tag_drop_out != 0 or self.shuffle_tags:
            entry.cond_text = self.create_text(entry.filename_text)
        if self.latent_sampling_method == "random":
            entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
        return entry


class GroupedBatchSampler(Sampler):
    def __init__(self, data_source: PersonalizedBase, batch_size: int):
        super().__init__(data_source)

        n = len(data_source)
        self.groups = data_source.groups
        self.len = n_batch = n // batch_size
        expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
        self.base = [int(e) // batch_size for e in expected]
        self.n_rand_batches = nrb = n_batch - sum(self.base)
        self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
        self.batch_size = batch_size

    def __len__(self):
        return self.len

    def __iter__(self):
        b = self.batch_size

        for g in self.groups:
            shuffle(g)

        batches = []
        for g in self.groups:
            batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
        for _ in range(self.n_rand_batches):
            rand_group = choices(self.groups, self.probs)[0]
            batches.append(choices(rand_group, k=b))

        shuffle(batches)

        yield from batches


class PersonalizedDataLoader(DataLoader):
    def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
        super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
        if latent_sampling_method == "random":
            self.collate_fn = collate_wrapper_random
        else:
            self.collate_fn = collate_wrapper


class BatchLoader:
    def __init__(self, data):
        self.cond_text = [entry.cond_text for entry in data]
        self.cond = [entry.cond for entry in data]
        self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
        if all(entry.weight is not None for entry in data):
            self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
        else:
            self.weight = None
        #self.emb_index = [entry.emb_index for entry in data]
        #print(self.latent_sample.device)

    def pin_memory(self):
        self.latent_sample = self.latent_sample.pin_memory()
        return self

def collate_wrapper(batch):
    return BatchLoader(batch)

class BatchLoaderRandom(BatchLoader):
    def __init__(self, data):
        super().__init__(data)

    def pin_memory(self):
        return self

def collate_wrapper_random(batch):
    return BatchLoaderRandom(batch)

 

Stable Diffusion 웹 인터페이스의 modules.textual_inversion 모듈에 속하는 파일로, 텍스트 역송출 과정에서 사용되는 데이터셋 관련 기능을 제공합니다. 

  1. PersonalizedBase 클래스: 텍스트 역송출을 위한 데이터셋을 정의하는 클래스입니다. 이 클래스는 PyTorch의 torch.utils.data.Dataset 클래스를 상속받아 구현되어 있으며, 데이터셋의 기본적인 동작을 제공합니다. 데이터셋은 이미지와 텍스트 조건으로 구성되며, 주어진 데이터 디렉토리에서 이미지와 조건 텍스트를 로드하여 제공합니다.
  2. PersonalizedDataLoader 클래스: PersonalizedBase 데이터셋을 DataLoader로 감싸주는 클래스입니다. 이 클래스는 PyTorch의 torch.utils.data.DataLoader 클래스를 상속받아 구현되어 있으며, 데이터셋을 미니배치로 나누어 학습에 활용할 수 있도록 합니다. 또한, 데이터셋의 샘플링 방법과 배치 크기, 핀 메모리 사용 여부 등을 설정할 수 있습니다.
  3. load_metadata 함수: 데이터셋 메타데이터를 로드하는 함수입니다. 데이터셋 디렉토리에 저장된 메타데이터 파일을 읽어와 필요한 정보를 반환합니다. 메타데이터는 데이터셋의 이미지 경로, 조건 텍스트, 가중치 등의 정보를 포함하고 있습니다.
  4. load_text 함수: 텍스트 데이터를 로드하는 함수입니다. 주어진 텍스트 파일에서 텍스트 데이터를 읽어와 리스트로 반환합니다.
  5. load_image 함수: 이미지 데이터를 로드하는 함수입니다. 주어진 이미지 파일 경로에서 이미지 데이터를 읽어와 NumPy 배열로 반환합니다.
  6. preprocess_image 함수: 이미지 데이터를 전처리하는 함수입니다. 이미지를 원하는 크기로 조정하고, 필요에 따라 정규화 또는 색상 채널 변환을 수행합니다.
  7. stack_conds 함수: 조건 텍스트 데이터를 전처리하는 함수입니다. 주어진 조건 텍스트를 텐서로 변환하여 반환합니다.

dataset.py 파일은 텍스트 역송출 모델에 필요한 데이터셋 관련 기능을 제공하여, 이미지와 조건 텍스트를 효율적으로 로드하고 처리할 수 있도록 도와줍니다. 이를 통해 모델 학습에 필요한 데이터를 쉽게 구성할 수 있습니다.

반응형

'Stable Diffusion WebUI' 카테고리의 다른 글

[Module] Textual Inversion #4  (0) 2023.06.02
[Module] Textual Inversion #3  (0) 2023.06.02
[Module] Textual Inversion #1  (0) 2023.06.02
[Module] Diffusion #3  (0) 2023.06.01
[Module] Diffusion #2  (1) 2023.06.01
profile

이런저런 IT 이야기

@이런저런 IT 이야기

포스팅이 좋았다면 "좋아요❤️" 또는 "구독👍🏻" 해주세요!

profile on loading

Loading...

검색 태그