Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions invokeai/app/services/model_records/model_records_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from invokeai.backend.model_manager.config import (
AnyModelConfig,
ControlAdapterDefaultSettings,
LoraModelDefaultSettings,
MainModelDefaultSettings,
)
from invokeai.backend.model_manager.taxonomy import (
Expand Down Expand Up @@ -83,8 +84,8 @@ class ModelRecordChanges(BaseModelExcludeNull):
file_size: Optional[int] = Field(description="Size of model file", default=None)
format: Optional[str] = Field(description="format of model file", default=None)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
description="Default settings for this model", default=None
default_settings: Optional[MainModelDefaultSettings | LoraModelDefaultSettings | ControlAdapterDefaultSettings] = (
Field(description="Default settings for this model", default=None)
)

# Checkpoint-specific changes
Expand Down
10 changes: 9 additions & 1 deletion invokeai/backend/model_manager/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ class MainModelDefaultSettings(BaseModel):
model_config = ConfigDict(extra="forbid")


class LoraModelDefaultSettings(BaseModel):
weight: float | None = Field(default=None, ge=-1, le=2, description="Default weight for this model")
model_config = ConfigDict(extra="forbid")


class ControlAdapterDefaultSettings(BaseModel):
# This could be narrowed to controlnet processor nodes, but they change. Leaving this a string is safer.
preprocessor: str | None
Expand Down Expand Up @@ -287,6 +292,9 @@ class LoRAConfigBase(ABC, BaseModel):

type: Literal[ModelType.LoRA] = ModelType.LoRA
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[LoraModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)

@classmethod
def flux_lora_format(cls, mod: ModelOnDisk):
Expand Down Expand Up @@ -748,7 +756,7 @@ def get_model_discriminator_value(v: Any) -> str:
]

AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, ControlAdapterDefaultSettings]
AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, LoraModelDefaultSettings, ControlAdapterDefaultSettings]


class ModelConfigFactory:
Expand Down
7 changes: 7 additions & 0 deletions invokeai/backend/model_manager/legacy_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
AnyModelConfig,
ControlAdapterDefaultSettings,
InvalidModelConfigException,
LoraModelDefaultSettings,
MainModelDefaultSettings,
ModelConfigFactory,
SubmodelDefinition,
Expand Down Expand Up @@ -217,6 +218,8 @@ def probe(
if not fields["default_settings"]:
if fields["type"] in {ModelType.ControlNet, ModelType.T2IAdapter, ModelType.ControlLoRa}:
fields["default_settings"] = get_default_settings_control_adapters(fields["name"])
if fields["type"] in {ModelType.LoRA}:
fields["default_settings"] = get_default_settings_lora()
elif fields["type"] is ModelType.Main:
fields["default_settings"] = get_default_settings_main(fields["base"])

Expand Down Expand Up @@ -543,6 +546,10 @@ def get_default_settings_control_adapters(model_name: str) -> Optional[ControlAd
return None


def get_default_settings_lora() -> LoraModelDefaultSettings:
return LoraModelDefaultSettings()


def get_default_settings_main(model_base: BaseModelType) -> Optional[MainModelDefaultSettings]:
if model_base is BaseModelType.StableDiffusion1 or model_base is BaseModelType.StableDiffusion2:
return MainModelDefaultSettings(width=512, height=512)
Expand Down
3 changes: 3 additions & 0 deletions invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,9 @@
}
}
},
"lora": {
"weight": "Weight"
},
"metadata": {
"allPrompts": "All Prompts",
"cfgScale": "CFG scale",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import type { SliceConfig } from 'app/store/types';
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
import { type LoRA, zLoRA } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { DEFAULT_LORA_WEIGHT_CONFIG } from 'features/system/store/configSlice';
import type { LoRAModelConfig } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
import z from 'zod';
Expand All @@ -13,11 +14,6 @@ const zLoRAsState = z.object({
});
type LoRAsState = z.infer<typeof zLoRAsState>;

const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
weight: 0.75,
isEnabled: true,
};

const getInitialState = (): LoRAsState => ({
loras: [],
});
Expand All @@ -32,6 +28,10 @@ const slice = createSlice({
reducer: (state, action: PayloadAction<{ model: LoRAModelConfig; id: string }>) => {
const { model, id } = action.payload;
const parsedModel = zModelIdentifierField.parse(model);
const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
weight: model.default_settings?.weight ?? DEFAULT_LORA_WEIGHT_CONFIG.initial,
isEnabled: true,
};
state.loras.push({ ...defaultLoRAConfig, model: parsedModel, id });
},
prepare: (payload: { model: LoRAModelConfig }) => ({ payload: { ...payload, id: uuidv4() } }),
Expand Down Expand Up @@ -87,3 +87,7 @@ export const lorasSliceConfig: SliceConfig<typeof slice> = {

export const selectLoRAsSlice = (state: RootState) => state.loras;
export const selectAddedLoRAs = createSelector(selectLoRAsSlice, (loras) => loras.loras);
export const buildSelectLoRA = (id: string) =>
createSelector([selectLoRAsSlice], (loras) => {
return selectLoRA(loras, id);
});
29 changes: 12 additions & 17 deletions invokeai/frontend/web/src/features/lora/components/LoRACard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,22 @@ import {
Switch,
Text,
} from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import {
buildSelectLoRA,
loraDeleted,
loraIsEnabledChanged,
loraWeightChanged,
selectLoRAsSlice,
} from 'features/controlLayers/store/lorasSlice';
import type { LoRA } from 'features/controlLayers/store/types';
import { DEFAULT_LORA_WEIGHT_CONFIG } from 'features/system/store/configSlice';
import { memo, useCallback, useMemo } from 'react';
import { PiTrashSimpleBold } from 'react-icons/pi';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';

const marks = [-1, 0, 1, 2];

export const LoRACard = memo((props: { id: string }) => {
const selectLoRA = useMemo(
() => createSelector(selectLoRAsSlice, ({ loras }) => loras.find(({ id }) => id === props.id)),
[props.id]
);
const selectLoRA = useMemo(() => buildSelectLoRA(props.id), [props.id]);
const lora = useAppSelector(selectLoRA);

if (!lora) {
Expand Down Expand Up @@ -83,22 +78,22 @@ const LoRAContent = memo(({ lora }: { lora: LoRA }) => {
<CompositeSlider
value={lora.weight}
onChange={handleChange}
min={-1}
max={2}
step={0.01}
marks={marks}
defaultValue={0.75}
min={DEFAULT_LORA_WEIGHT_CONFIG.sliderMin}
max={DEFAULT_LORA_WEIGHT_CONFIG.sliderMax}
step={DEFAULT_LORA_WEIGHT_CONFIG.coarseStep}
marks={DEFAULT_LORA_WEIGHT_CONFIG.marks.slice()}
defaultValue={DEFAULT_LORA_WEIGHT_CONFIG.initial}
isDisabled={!lora.isEnabled}
/>
<CompositeNumberInput
value={lora.weight}
onChange={handleChange}
min={-10}
max={10}
step={0.01}
min={DEFAULT_LORA_WEIGHT_CONFIG.numberInputMin}
max={DEFAULT_LORA_WEIGHT_CONFIG.numberInputMax}
step={DEFAULT_LORA_WEIGHT_CONFIG.coarseStep}
w={20}
flexShrink={0}
defaultValue={0.75}
defaultValue={DEFAULT_LORA_WEIGHT_CONFIG.initial}
isDisabled={!lora.isEnabled}
/>
</CardBody>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import { isNil } from 'es-toolkit/compat';
import { DEFAULT_LORA_WEIGHT_CONFIG } from 'features/system/store/configSlice';
import { useMemo } from 'react';
import type { LoRAModelConfig } from 'services/api/types';

export const useLoRAModelDefaultSettings = (modelConfig: LoRAModelConfig) => {
const defaultSettingsDefaults = useMemo(() => {
return {
weight: {
isEnabled: !isNil(modelConfig?.default_settings?.weight),
value: modelConfig?.default_settings?.weight ?? DEFAULT_LORA_WEIGHT_CONFIG.initial,
},
};
}, [modelConfig?.default_settings]);

return defaultSettingsDefaults;
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
import { DEFAULT_LORA_WEIGHT_CONFIG } from 'features/system/store/configSlice';
import { memo, useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';

import type { LoRAModelDefaultSettingsFormData } from './LoRAModelDefaultSettings';

type DefaultWeight = LoRAModelDefaultSettingsFormData['weight'];

export const DefaultWeight = memo((props: UseControllerProps<LoRAModelDefaultSettingsFormData, 'weight'>) => {
const { field } = useController(props);
const { t } = useTranslation();

const onChange = useCallback(
(v: number) => {
const updatedValue = {
...field.value,
value: v,
};
field.onChange(updatedValue);
},
[field]
);

const value = useMemo(() => {
return field.value.value;
}, [field.value]);

const isDisabled = useMemo(() => {
return !field.value.isEnabled;
}, [field.value]);

return (
<FormControl flexDir="column" gap={2} alignItems="flex-start">
<Flex justifyContent="space-between" w="full">
<InformationalPopover feature="loraWeight">
<FormLabel>{t('lora.weight')}</FormLabel>
</InformationalPopover>
<SettingToggle control={props.control} name="weight" />
</Flex>

<Flex w="full" gap={4}>
<CompositeSlider
value={value}
min={DEFAULT_LORA_WEIGHT_CONFIG.sliderMin}
max={DEFAULT_LORA_WEIGHT_CONFIG.sliderMax}
step={DEFAULT_LORA_WEIGHT_CONFIG.coarseStep}
fineStep={DEFAULT_LORA_WEIGHT_CONFIG.fineStep}
onChange={onChange}
marks={DEFAULT_LORA_WEIGHT_CONFIG.marks.slice()}
isDisabled={isDisabled}
/>
<CompositeNumberInput
value={value}
min={DEFAULT_LORA_WEIGHT_CONFIG.numberInputMin}
max={DEFAULT_LORA_WEIGHT_CONFIG.numberInputMax}
step={DEFAULT_LORA_WEIGHT_CONFIG.coarseStep}
fineStep={DEFAULT_LORA_WEIGHT_CONFIG.fineStep}
onChange={onChange}
isDisabled={isDisabled}
/>
</Flex>
</FormControl>
);
});

DefaultWeight.displayName = 'DefaultWeight';
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import { Button, Flex, Heading, SimpleGrid } from '@invoke-ai/ui-library';
import { useLoRAModelDefaultSettings } from 'features/modelManagerV2/hooks/useLoRAModelDefaultSettings';
import { DefaultWeight } from 'features/modelManagerV2/subpanels/ModelPanel/LoRAModelDefaultSettings/DefaultWeight';
import type { FormField } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings';
import { toast } from 'features/toast/toast';
import { memo, useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { PiCheckBold } from 'react-icons/pi';
import { useUpdateModelMutation } from 'services/api/endpoints/models';
import type { LoRAModelConfig } from 'services/api/types';

export type LoRAModelDefaultSettingsFormData = {
weight: FormField<number>;
};

type Props = {
modelConfig: LoRAModelConfig;
};

export const LoRAModelDefaultSettings = memo(({ modelConfig }: Props) => {
const { t } = useTranslation();

const defaultSettingsDefaults = useLoRAModelDefaultSettings(modelConfig);

const [updateModel, { isLoading: isLoadingUpdateModel }] = useUpdateModelMutation();

const { handleSubmit, control, formState, reset } = useForm<LoRAModelDefaultSettingsFormData>({
defaultValues: defaultSettingsDefaults,
});

const onSubmit = useCallback<SubmitHandler<LoRAModelDefaultSettingsFormData>>(
(data) => {
const body = {
weight: data.weight.isEnabled ? data.weight.value : null,
};

updateModel({
key: modelConfig.key,
body: { default_settings: body },
})
.unwrap()
.then((_) => {
toast({
id: 'DEFAULT_SETTINGS_SAVED',
title: t('modelManager.defaultSettingsSaved'),
status: 'success',
});
reset(data);
})
.catch((error) => {
if (error) {
toast({
id: 'DEFAULT_SETTINGS_SAVE_FAILED',
title: `${error.data.detail} `,
status: 'error',
});
}
});
},
[updateModel, modelConfig.key, t, reset]
);

return (
<>
<Flex gap="4" justifyContent="space-between" w="full" pb={4}>
<Heading fontSize="md">{t('modelManager.defaultSettings')}</Heading>
<Button
size="sm"
leftIcon={<PiCheckBold />}
colorScheme="invokeYellow"
isDisabled={!formState.isDirty}
onClick={handleSubmit(onSubmit)}
isLoading={isLoadingUpdateModel}
>
{t('common.save')}
</Button>
</Flex>

<SimpleGrid columns={2} gap={8}>
<DefaultWeight control={control} name="weight" />
</SimpleGrid>
</>
);
});

LoRAModelDefaultSettings.displayName = 'LoRAModelDefaultSettings';
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { Box, Flex, SimpleGrid } from '@invoke-ai/ui-library';
import { ControlAdapterModelDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/ControlAdapterModelDefaultSettings/ControlAdapterModelDefaultSettings';
import { LoRAModelDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/LoRAModelDefaultSettings/LoRAModelDefaultSettings';
import { ModelConvertButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton';
import { ModelEditButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelEditButton';
import { ModelHeader } from 'features/modelManagerV2/subpanels/ModelPanel/ModelHeader';
Expand Down Expand Up @@ -79,9 +80,13 @@ export const ModelView = memo(({ modelConfig }: Props) => {
{(modelConfig.type === 'controlnet' ||
modelConfig.type === 't2i_adapter' ||
modelConfig.type === 'control_lora') && <ControlAdapterModelDefaultSettings modelConfig={modelConfig} />}
{(modelConfig.type === 'main' || modelConfig.type === 'lora') && (
<TriggerPhrases modelConfig={modelConfig} />
{modelConfig.type === 'lora' && (
<>
<LoRAModelDefaultSettings modelConfig={modelConfig} />
<TriggerPhrases modelConfig={modelConfig} />
</>
)}
{modelConfig.type === 'main' && <TriggerPhrases modelConfig={modelConfig} />}
</Box>
)}
<Box overflowY="auto" layerStyle="second" borderRadius="base" p={4}>
Expand Down
Loading