Skip to content

Commit

Permalink
修改数据集添加和显示列
Browse files Browse the repository at this point in the history
  • Loading branch information
data-infra committed Jun 26, 2024
1 parent f744e78 commit ea2cde2
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 51 deletions.
21 changes: 16 additions & 5 deletions myapp/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ def add_project(project_type, name, describe, expand={}):
add_project('job-template', __('数据预处理'), __('结构化话数据特征处理'), {"index": 3})
add_project('job-template', __('数据处理工具'), __('数据的单机或分布式处理任务,ray/spark/hadoop/volcanojob'), {"index": 4})
add_project('job-template', __('特征处理'), __('特征处理相关功能'), {"index": 5})
add_project('job-template', __('图像处理'), __('图像处理相关功能'), {"index": 5.1})
add_project('job-template', __('视频处理'), __('视频处理相关功能'), {"index": 5.2})
add_project('job-template', __('音频处理'), __('音频处理相关功能'), {"index": 5.3})
add_project('job-template', __('文本处理'), __('文本处理相关功能'), {"index": 5.4})
add_project('job-template', __('机器学习框架'), __('传统机器学习框架,sklearn'), {"index": 6})
add_project('job-template', __('机器学习算法'), __('传统机器学习,lr/决策树/gbdt/xgb/fm等'), {"index": 7})
add_project('job-template', __('深度学习'), __('深度框架训练,tf/pytorch/mxnet/mpi/horovod/kaldi等'), {"index": 8})
Expand Down Expand Up @@ -452,7 +456,7 @@ def create_dataset(**kwargs):
dataset = Dataset()
dataset.name = kwargs['name']
dataset.field = kwargs.get('field', '')
dataset.version = 'latest'
dataset.version = kwargs.get('version', 'latest')
dataset.label = kwargs.get('label', '')
dataset.status = kwargs.get('status', '')
dataset.describe = kwargs.get('describe', '')
Expand All @@ -472,7 +476,8 @@ def create_dataset(**kwargs):
dataset.storage_class = kwargs.get('storage_class', '')
dataset.storage_size = kwargs.get('storage_size', '')
dataset.download_url = kwargs.get('download_url', '')
dataset.owner = 'admin'
dataset.owner = kwargs.get('owner', 'admin')
dataset.features = kwargs.get('features', '{}')
dataset.created_by_fk = 1
dataset.changed_by_fk = 1
db.session.add(dataset)
Expand Down Expand Up @@ -631,7 +636,7 @@ def create_inference(project_name, service_name, service_describe, image_name, c
from myapp.views.view_inferenceserving import InferenceService_ModelView_base
inference_class = InferenceService_ModelView_base()
inference_class.src_item_json = {}
inference_class.pre_add(service)
inference_class.use_expand(service)

db.session.add(service)
db.session.commit()
Expand Down Expand Up @@ -756,10 +761,13 @@ def add_chat(chat_path):
if not chat.id:
db.session.add(chat)
db.session.commit()
print(f'add chat {name} success')
except Exception as e:
print(e)
# traceback.print_exc()

# 添加chat
# if conf.get('BABEL_DEFAULT_LOCALE','zh')=='zh':
try:
print('begin add chat')
init_file = os.path.join(init_dir, 'init-chat.json')
Expand All @@ -768,7 +776,7 @@ def add_chat(chat_path):
except Exception as e:
print(e)
# traceback.print_exc()
# 添加chat

# if conf.get('BABEL_DEFAULT_LOCALE','zh')=='zh':
try:
SQLALCHEMY_DATABASE_URI = os.getenv('MYSQL_SERVICE', '')
Expand Down Expand Up @@ -819,6 +827,7 @@ def add_chat(chat_path):
# traceback.print_exc()
# 添加ETL pipeline
try:
print('begin add etl pipeline')
from myapp.models.model_etl_pipeline import ETL_Pipeline
tables = db.session.query(ETL_Pipeline).all()
if len(tables) == 0:
Expand All @@ -840,6 +849,7 @@ def add_chat(chat_path):

# 添加nni超参搜索
try:
print('begin add nni')
from myapp.models.model_nni import NNI
nni = db.session.query(NNI).all()
if len(nni) == 0:
Expand All @@ -862,14 +872,15 @@ def add_chat(chat_path):
resource_gpu=nni.get('resource_gpu', '0'),
))
db.session.commit()
print('添加etl pipeline成功')
print('添加nni 超参搜索成功')
except Exception as e:
print(e)
# traceback.print_exc()


# 添加镜像在线构建
try:
print('begin add docker')
from myapp.models.model_docker import Docker
docker = db.session.query(Docker).all()
if len(docker) == 0:
Expand Down
22 changes: 22 additions & 0 deletions myapp/models/model_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,37 @@ def url_html(self):
html+='<a target=_blank href="%s">%s</a><br>'%(url.strip(),url.strip())
return Markup('<div>%s</div>'%html)

def label_html(self):
    """Render the dataset's display label, hyperlinked to the first related URL.

    ``self.url`` is a newline-separated list of URLs; when at least one
    non-blank entry exists, the label is wrapped in an ``<a>`` tag pointing
    at the first one, otherwise the plain label text is returned.
    """
    # Collect the non-blank entries from the newline-separated url field.
    candidates = []
    if self.url:
        for line in self.url.split('\n'):
            trimmed = line.strip()
            if trimmed:
                candidates.append(trimmed)
    if not candidates:
        return self.label
    first = candidates[0]
    return Markup('<a target=_blank href="%s">%s</a>' % (first, self.label))

@property
def path_html(self):
    """Render the newline-separated ``path`` field as HTML for the list view.

    Each line of ``self.path`` is treated as an in-platform path (typically
    under ``/mnt/``). A line is rendered as a download link only when its
    translated host path exists on disk; an existing directory is linked via
    its ``data.csv`` file when present, otherwise the line is plain text.
    """
    paths= self.path.split('\n')

    html = ''
    for path in paths:
        exist_file=False
        if path.strip():
            # Map the in-container /mnt/ prefix to the host storage root so the
            # existence checks below run against the web server's filesystem.
            # NOTE(review): assumes this process can see /data/k8s/kubeflow/... — confirm deployment.
            host_path = path.replace('/mnt/','/data/k8s/kubeflow/pipeline/workspace/').strip()
            if os.path.exists(host_path):
                if os.path.isdir(host_path):
                    # Directories are only downloadable through their data.csv summary file.
                    data_csv_path = os.path.join(host_path,'data.csv')
                    if os.path.exists(data_csv_path):
                        # Redirect the displayed/linked path to the data.csv inside it.
                        path = os.path.join(path,'data.csv')
                        exist_file = True
                else:
                    # A plain existing file is downloadable directly.
                    exist_file=True
            if exist_file:
                # NOTE(review): Flask's request.host_url already ends with '/',
                # so this produces 'http://host//static/...'; also `path` starts
                # with /mnt/, making the '/data/k8s/kubeflow/' replace a likely
                # no-op — verify the static route tolerates both.
                download_url = request.host_url+'/static/'+path.replace('/data/k8s/kubeflow/','')
                html += f'<a target=_blank href="{download_url}">{path.strip()}</a><br>'
            else:
                html += f'{path.strip()}<br>'

    return Markup('<div>%s</div>'%html)


Expand Down
2 changes: 1 addition & 1 deletion myapp/models/model_notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def name_url(self):
# url= url + "#"+self.mount

# 对于有边缘节点,直接使用边缘集群的代理ip
if SERVICE_EXTERNAL_IP:
if SERVICE_EXTERNAL_IP and conf.get('ENABLE_EDGE_K8S',False):
SERVICE_EXTERNAL_IP = SERVICE_EXTERNAL_IP.split('|')[-1].strip()
from myapp.utils import core
meet_ports = core.get_not_black_port(10000 + 10 * self.id)
Expand Down
19 changes: 13 additions & 6 deletions myapp/models/model_train_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json

import os.path
import re
from flask_appbuilder import Model
from sqlalchemy.orm import relationship
from sqlalchemy import Text
Expand All @@ -14,7 +15,7 @@
from flask import Markup
metadata = Model.metadata
conf = app.config

import pysnooper

class Training_Model(Model,AuditMixinNullable,MyappModelBase):
__tablename__ = 'model'
Expand Down Expand Up @@ -59,11 +60,17 @@ def project_url(self):

@property
def deploy(self):
download_url = ''
if self.path or self.download_url:
download_url = f'{__("下载")} |'
if self.download_url and self.download_url.strip():
download_url = f'<a href="/training_model_modelview/api/download/{self.id}">{__("下载")}</a> |'
else:
download_url = f'{__("下载")} |'
if self.path and self.path.strip():
if re.match('^/mnt/', self.path):
local_path = f'/data/k8s/kubeflow/pipeline/workspace/{self.path.strip().replace("/mnt/","")}'
if os.path.exists(local_path):
download_url = f'<a href="/training_model_modelview/api/download/{self.id}">{__("下载")}</a> |'
if 'http://' in self.path or 'https://' in self.path:
download_url = f'<a href="/training_model_modelview/api/download/{self.id}">{__("下载")}</a> |'

ops=download_url+f'''
<a href="/training_model_modelview/api/deploy/{self.id}">{__("发布")}</a>
'''
Expand Down
82 changes: 43 additions & 39 deletions myapp/views/view_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
import re
import shutil

import zipfile, pandas
from flask_appbuilder import action
from myapp.views.baseSQLA import MyappSQLAInterface as SQLAInterface
from wtforms.validators import DataRequired, Regexp
Expand Down Expand Up @@ -29,7 +29,7 @@
from flask_appbuilder import expose
from myapp.views.view_team import Project_Join_Filter, filter_join_org_project
from myapp.models.model_dataset import Dataset

from myapp.utils import core
conf = app.config


Expand Down Expand Up @@ -57,9 +57,7 @@ class Dataset_ModelView_base():
order_columns = ['id']
base_filters = [["id", Dataset_Filter, lambda: []]] # 设置权限过滤器

add_columns = ['name', 'version', 'label', 'describe', 'source_type', 'source', 'field',
'usage', 'storage_class', 'file_type', 'url', 'download_url', 'path',
'storage_size', 'entries_num', 'duration', 'price', 'status', 'icon', 'owner', 'features']
add_columns = ['name', 'version', 'label', 'describe', 'url', 'download_url', 'path', 'icon', 'owner', 'features']
show_columns = ['id', 'name', 'version', 'label', 'describe', 'segment', 'source_type', 'source',
'industry', 'field', 'usage', 'storage_class', 'file_type', 'status', 'url',
'path', 'download_url', 'storage_size', 'entries_num', 'duration', 'price', 'status', 'icon',
Expand All @@ -75,25 +73,26 @@ class Dataset_ModelView_base():
"years": _("数据年份"),
"url": _("相关网址"),
"url_html": _("相关网址"),
"label_html": _("中文名"),
"path": _("本地路径"),
"path_html": _("本地路径"),
"entries_num": _("条目数量"),
"duration": _("文件时长"),
"price": _("价格"),
"icon": _("示例图"),
"icon_html": _("示例图"),
"icon": _("预览图"),
"icon_html": _("预览图"),
"ops_html": _("操作"),
"features": _("特征列"),
"segment": _("分区")
}

edit_columns = add_columns
list_columns = ['icon_html', 'name', 'version', 'label', 'describe','owner', 'source_type', 'source', 'status',
'field', 'url_html', 'download_url_html', 'usage', 'storage_class', 'file_type', 'path_html', 'storage_size', 'entries_num', 'price']
list_columns = ['icon_html', 'name', 'version', 'label_html', 'describe','owner', 'ops_html', 'path_html', 'download_url_html']

cols_width = {
"name": {"type": "ellip1", "width": 200},
"name": {"type": "ellip1", "width": 150},
"label": {"type": "ellip2", "width": 200},
"label_html": {"type": "ellip2", "width": 200},
"version": {"type": "ellip2", "width": 100},
"describe": {"type": "ellip2", "width": 300},
"field": {"type": "ellip1", "width": 100},
Expand All @@ -118,33 +117,30 @@ class Dataset_ModelView_base():
"ops_html": {"type": "ellip1", "width": 200},
}
features_demo = '''
填写规则
{
"column1": {
# feature type
"type": "dict,list,tuple,Value,Sequence,Array2D,Array3D,Array4D,Array5D,Translation,TranslationVariableLanguages,Audio,Image,Video,ClassLabel",
"_type": "dict,list,tuple,Value,Sequence,Array2D,Array3D,Array4D,Array5D,Translation,TranslationVariableLanguages,Audio,Image,Video",
# data type in dict,list,tuple,Value,Sequence,Array2D,Array3D,Array4D,Array5D
"dtype": "null,bool,int8,int16,int32,int64,uint8,uint16,uint32,uint64,float16,float32,float64,time32[(s|ms)],time64[(us|ns)],timestamp[(s|ms|us|ns)],timestamp[(s|ms|us|ns),tz=(tzstring)],date32,date64,duration[(s|ms|us|ns)],decimal128(precision,scale),decimal256(precision,scale),binary,large_binary,string,large_string"
# length of Sequence
"length": 10
# dimension of Array2D,Array3D,Array4D,Array5D
"shape": (1, 2, 3, 4, 5),
# sampling rate of Audio
"sampling_rate":16000,
"mono": true,
"decode": true
# decode of Image
"decode": true
# class of ClassLabel
"num_classes":3,
"names":['class1','class2','class3']
},
}
}
示例:
{
"id": {
"_type": "Value",
"dtype": "string"
},
"image": {
"_type": "Image"
},
"box": {
"_type": "Value",
"dtype": "string"
}
}
'''
add_form_extra_fields = {
Expand All @@ -160,7 +156,7 @@ class Dataset_ModelView_base():
description= _('数据集版本'),
default='latest',
widget=BS3TextFieldWidget(),
validators=[DataRequired(), Regexp("^[a-z][a-z0-9_\-]*[a-z0-9]$"), ]
validators=[DataRequired(), Regexp("[a-z0-9_\-]*"), ]
),
"subdataset": StringField(
label= _('子数据集'),
Expand Down Expand Up @@ -255,19 +251,26 @@ class Dataset_ModelView_base():
),
"path": StringField(
label= _('本地路径'),
description='',
description='本地文件通过notebook上传到平台内,处理后,压缩成单个压缩文件,每行一个压缩文件地址',
widget=MyBS3TextAreaFieldWidget(rows=3),
default=''
),
"download_url": StringField(
label= _('下载地址'),
description='',
description='可以直接下载的链接地址,每行一个url',
widget=MyBS3TextAreaFieldWidget(rows=3),
default=''
),
"icon": StringField(
label=_('预览图'),
default='',
description=_('可以为图片地址,svg源码,或者帮助文档链接'),
widget=BS3TextFieldWidget(),
validators=[]
),
"features": StringField(
label= _('特征列'),
description= _('数据集中的列信息'),
description= _('数据集中的列信息,要求数据集中要有data.csv文件用于表示数据集中的全部数据'),
widget=MyBS3TextAreaFieldWidget(rows=3, tips=Markup('<pre><code>' + features_demo + "</code></pre>")),
default=''
)
Expand All @@ -280,13 +283,14 @@ class Dataset_ModelView_base():
def pre_add(self, item):
    """Normalize a Dataset record before insert.

    Fills in default owner/icon/version/subdataset values and re-serializes
    the ``features`` column as pretty-printed JSON (``"{}"`` when empty).
    """
    # Default the owner to the current user plus the wildcard marker.
    if not item.owner:
        item.owner = g.user.username + ",*"
    # Fall back to the stock dataset icon when none was supplied.
    if not item.icon:
        item.icon = '/static/assets/images/dataset.png'
    # Force inline SVG icons into a uniform 50px box.
    if item.icon and '</svg>' in item.icon:
        for dim in ('width', 'height'):
            item.icon = re.sub(dim + r'="\d+(\.\d+)?(px)?"', dim + '="50px"', item.icon)
    if not item.version:
        item.version = 'latest'
    if not item.subdataset:
        item.subdataset = item.name
    # Round-trip through json to validate and pretty-print the feature spec.
    item.features = "{}" if not item.features else json.dumps(json.loads(item.features), indent=4, ensure_ascii=False)
def pre_update(self, item):
    # Updates reuse the same defaulting/normalization logic as inserts.
    self.pre_add(item)

Expand Down Expand Up @@ -405,15 +409,15 @@ def path2url(path):
dataset = db.session.query(Dataset).filter_by(id=int(dataset_id)).first()
try:
download_url = []
if dataset.path:
if dataset.path and dataset.path.strip():
# 如果存储在集群数据集中心
# 如果存储在个人目录
paths = dataset.path.split('\n')
for path in paths:
download_url.append(path2url(path))

# 如果存储在外部链接
elif dataset.download_url:
elif dataset.download_url and dataset.download_url.strip():
download_url = dataset.download_url.split('\n')
else:
# 如果存储在对象存储中
Expand Down

0 comments on commit ea2cde2

Please sign in to comment.